Operators, expressions, and dunder methods
Mojo includes a variety of operators for manipulating values of different types. Generally, the operators are equivalent to those found in Python, though many operators also work with additional Mojo types such as SIMD vectors. Additionally, Mojo allows you to define the behavior of most of these operators for your own custom types by implementing special dunder (double underscore) methods.
This document contains the following three sections:
- Operators and expressions discusses Mojo's built-in operators and how they work with commonly used Mojo types.
- Implement operators for custom types describes the dunder methods that you can implement to support using operators with custom structs that you create.
- An example of implementing operators for a custom type shows a progressive example of writing a custom struct with support for several operators.
Operators and expressions
This section lists the operators that Mojo supports, along with their order of precedence and associativity, and describes how these operators behave with several commonly used built-in types.
Operator precedence and associativity
The table below lists the various Mojo operators, along with their order of precedence and associativity (also referred to as grouping). This table lists operators from the highest precedence to the lowest precedence.
| Operators | Description | Associativity (Grouping) |
|---|---|---|
| () | Parenthesized expression | Left to right |
| x[index], x[index:index] | Subscripting, slicing | Left to right |
| ** | Exponentiation | Right to left |
| +x, -x, ~x | Positive, negative, bitwise NOT | Right to left |
| *, @, /, //, % | Multiplication, matrix multiplication, division, floor division, remainder | Left to right |
| +, - | Addition and subtraction | Left to right |
| <<, >> | Shifts | Left to right |
| & | Bitwise AND | Left to right |
| ^ | Bitwise XOR | Left to right |
| \| | Bitwise OR | Left to right |
| in, not in, is, is not, <, <=, >, >=, !=, == | Comparisons, membership tests, identity tests | Left to right |
| not x | Boolean NOT | Right to left |
| x and y | Boolean AND | Left to right |
| x or y | Boolean OR | Left to right |
| if-else | Conditional expression | Right to left |
| := | Assignment expression (walrus operator) | Right to left |
Mojo supports the same operators as Python (plus a few extensions), and they have the same precedence levels. For example, the following arithmetic expression evaluates to 40:

    5 + 4 * 3 ** 2 - 1

It is equivalent to the following expression, which uses parentheses to make the order of evaluation explicit:

    (5 + (4 * (3 ** 2))) - 1

Associativity defines how operators of the same precedence level are grouped into expressions. The table indicates whether operators of a given level are left- or right-associative. For example, multiplication and division are left-associative, so the following expression results in a value of 3:

    3 * 4 / 2 / 2

It is equivalent to the following expression, which uses parentheses to make the order of evaluation explicit:

    ((3 * 4) / 2) / 2

In contrast, exponentiation is right-associative, so the following expression results in a value of 262,144:

    4 ** 3 ** 2

It is equivalent to the following expression, which uses parentheses to make the order of evaluation explicit:

    4 ** (3 ** 2)

Arithmetic and bitwise operators
Numeric types describes the different numeric types provided by the Mojo standard library. The arithmetic and bitwise operators have slightly different behavior depending on the types of values provided.
Int and UInt values
The Int and UInt types represent signed and unsigned integers of the word size of the CPU, typically 64 bits or 32 bits.
The Int and UInt types support all arithmetic operators except matrix multiplication (@), as well as all bitwise and shift operators. If both operands to a binary operator are Int values the result is an Int, if both operands are UInt values the result is a UInt, and if one operand is Int and the other UInt the result is an Int. The one exception for these types is true division, /, which always returns a Float64 type value.
    var a_int: Int = -7
    var b_int: Int = 4
    sum_int = a_int + b_int  # Result is type Int
    print("Int sum:", sum_int)

    var i_uint: UInt = 9
    var j_uint: UInt = 8
    sum_uint = i_uint + j_uint  # Result is type UInt
    print("UInt sum:", sum_uint)

    sum_mixed = a_int + Int(i_uint)  # Result is type Int
    print("Mixed sum:", sum_mixed)

    quotient_int = a_int / b_int  # Result is type Float64
    print("Int quotient:", quotient_int)
    quotient_uint = i_uint / j_uint  # Result is type Float64
    print("UInt quotient:", quotient_uint)

    Int sum: -3
    UInt sum: 17
    Mixed sum: 2
    Int quotient: -1.75
    UInt quotient: 1.125

SIMD values
The Mojo standard library defines the SIMD type to represent a fixed-size array of values that can fit into a processor's register. This allows you to take advantage of single instruction, multiple data operations in hardware to efficiently process multiple values in parallel. SIMD values of a numeric DType support all arithmetic operators except for matrix multiplication (@), though the left shift (<<) and right shift (>>) operators support only integral types. Additionally, SIMD values of an integral or boolean type support all bitwise operators. SIMD values apply the operators in an elementwise fashion, as shown in the following example:
    simd1 = SIMD[DType.int32, 4](2, 3, 4, 5)
    simd2 = SIMD[DType.int32, 4](-1, 2, -3, 4)
    simd3 = simd1 * simd2
    print(simd3)

    [-2, 6, -12, 20]

Scalar values are simply aliases for single-element SIMD vectors, so Float16 is just an alias for SIMD[DType.float16, 1]. Therefore Scalar values support the same set of arithmetic and bitwise operators.

    var f1: Float16 = 2.5
    var f2: Float16 = -4.0
    var f3 = f1 * f2  # Implicitly of type Float16
    print(f3)

    -10.0

When using these operators on SIMD values, Mojo requires both operands to have the same size and DType, and the result is a SIMD of the same size and DType. The operators do not automatically widen lower-precision SIMD values to higher precision. This means that the DType of each value must be the same, or else the result is a compilation error.

    var i8: Int8 = 8
    var f64: Float64 = 64.0
    result = i8 * f64

    error: invalid call to '__mul__': failed to infer parameter 'type' of parent struct 'SIMD'
        result = i8 * f64
                 ~~~^~~~~

If you need to apply an arithmetic or bitwise operator to two SIMD values of different types, you can explicitly convert a value to the desired type either by invoking its cast() method or by passing it as an argument to the constructor of the target type.
For example, to fix the previous example, add an explicit conversion:
    var i8: Int8 = 8
    var f64: Float64 = 64.0
    result = Float64(i8) * f64

Here are some more examples of converting SIMD values using both constructors and the cast() method:

    simd4 = SIMD[DType.float32, 4](2.2, 3.3, 4.4, 5.5)
    simd5 = SIMD[DType.int16, 4](-1, 2, -3, 4)
    simd6 = simd4 * simd5.cast[DType.float32]()  # Convert with cast() method
    print("simd6:", simd6)
    simd7 = simd5 + SIMD[DType.int16, 4](simd4)  # Convert with SIMD constructor
    print("simd7:", simd7)

    simd6: [-2.2, 6.6, -13.200001, 22.0]
    simd7: [1, 5, 1, 9]

One exception is that the exponentiation operator, **, is overloaded so that you can specify an Int type exponent. All values in the SIMD are exponentiated to the same power.

    base_simd = SIMD[DType.float64, 4](1.1, 2.2, 3.3, 4.4)
    var power: Int = 2
    pow_simd = base_simd ** power  # Result is SIMD[DType.float64, 4]
    print(pow_simd)

    [1.2100000000000002, 4.8400000000000007, 10.889999999999999, 19.360000000000003]

There are three operators related to division:
- /, the "true division" operator, performs floating point division for SIMD values with a floating point DType. For SIMD values with an integral DType, true division truncates the quotient to an integral result.

        num_float16 = SIMD[DType.float16, 4](3.5, -3.5, 3.5, -3.5)
        denom_float16 = SIMD[DType.float16, 4](2.5, 2.5, -2.5, -2.5)

        num_int32 = SIMD[DType.int32, 4](5, -6, 7, -8)
        denom_int32 = SIMD[DType.int32, 4](2, 3, -4, -5)

        # Result is SIMD[DType.float16, 4]
        true_quotient_float16 = num_float16 / denom_float16
        print("True float16 division:", true_quotient_float16)

        # Result is SIMD[DType.int32, 4]
        true_quotient_int32 = num_int32 / denom_int32
        print("True int32 division:", true_quotient_int32)

        True float16 division: [1.4003906, -1.4003906, -1.4003906, 1.4003906]
        True int32 division: [2, -2, -1, 1]

- //, the "floor division" operator, performs division and rounds down the result to the nearest integer. The resulting SIMD is still the same type as the original operands. For example:

        # Result is SIMD[DType.float16, 4]
        var floor_quotient_float16 = num_float16 // denom_float16
        print("Floor float16 division:", floor_quotient_float16)

        # Result is SIMD[DType.int32, 4]
        var floor_quotient_int32 = num_int32 // denom_int32
        print("Floor int32 division:", floor_quotient_int32)

        Floor float16 division: [1.0, -2.0, -2.0, 1.0]
        Floor int32 division: [2, -2, -2, 1]

- %, the modulo operator, returns the remainder after dividing the numerator by the denominator an integral number of times. The relationship between the // and % operators can be defined as num == denom * (num // denom) + (num % denom). For example:

        # Result is SIMD[DType.float16, 4]
        var remainder_float16 = num_float16 % denom_float16
        print("Modulo float16:", remainder_float16)

        # Result is SIMD[DType.int32, 4]
        var remainder_int32 = num_int32 % denom_int32
        print("Modulo int32:", remainder_int32)
        print()

        # Result is SIMD[DType.float16, 4]
        var result_float16 = denom_float16 * floor_quotient_float16 + remainder_float16
        print("Result float16:", result_float16)

        # Result is SIMD[DType.int32, 4]
        var result_int32 = denom_int32 * floor_quotient_int32 + remainder_int32
        print("Result int32:", result_int32)

        Modulo float16: [1.0, 1.5, -1.5, -1.0]
        Modulo int32: [1, 0, -1, -3]

        Result float16: [3.5, -3.5, 3.5, -3.5]
        Result int32: [5, -6, 7, -8]
IntLiteral and FloatLiteral values
IntLiteral and FloatLiteral are compile-time, numeric values. When they are used in a compile-time context, they are arbitrary-precision values. When they are used in a run-time context, they are materialized as Int and Float64 type values, respectively.
As an example, the following code causes a compile-time error because the calculated IntLiteral value is too large to store in an Int variable:
    alias big_int = (1 << 65) + 123456789  # IntLiteral
    var too_big_int: Int = big_int
    print("Result:", too_big_int)

    note: integer value 36893488147542560021 requires 67 bits to store, but the destination bit width is only 64 bits wide

However, in the following example, taking that same IntLiteral value, dividing by the IntLiteral 10, and then assigning the result to an Int variable compiles and runs successfully, because the final IntLiteral quotient can fit in a 64-bit Int.

    alias big_int = (1 << 65) + 123456789  # IntLiteral
    var not_too_big_int: Int = big_int // 10
    print("Result:", not_too_big_int)

    Result: 3689348814754256002

In a compile-time context, IntLiteral and FloatLiteral values support all arithmetic operators except exponentiation (**), and IntLiteral values support all bitwise and shift operators. In a run-time context, materialized IntLiteral values are Int values and therefore support the same operators as Int, and materialized FloatLiteral values are Float64 values and therefore support the same operators as Float64.
Comparison operators
Mojo supports a standard set of comparison operators: ==, !=, <, <=, >, and >=. However, their behavior depends on the types of the values being compared.
The remainder of this section describes numerical comparison operators. String comparisons are discussed in the String operators section. Several other types in the Mojo standard library support various comparison operators, in particular the "equal" and "not equal" comparisons. Consult the API documentation for a type to determine whether any comparison operators are supported.
Bool-returning comparisons
These comparisons return a single Bool value:
- Int, UInt, IntLiteral, and any type that can be implicitly converted to Int or UInt do standard numerical comparison with a Bool result.

- Equality operators (== and !=) with multi-element SIMD values return a Bool result using reduction semantics. The comparison is True only if it's true for all corresponding elements. For example:

        simd8 = SIMD[DType.int32, 4](1, 2, 3, 2)
        simd9 = SIMD[DType.int32, 4](1, 2, 4, 2)
        print("simd8 == simd9:", simd8 == simd9)  # False (element 2 differs)
        print("simd8 != simd9:", simd8 != simd9)  # True (not all elements equal)

        simd8 == simd9: False
        simd8 != simd9: True

- Inequality operators (<, <=, >, >=) with multi-element SIMD values are not supported. These operators only work with scalar (single-element) SIMD values.

- Scalar values are simply aliases for single-element SIMD vectors and support all comparison operators with Bool results:

        var float1: Float16 = 12.345  # SIMD[DType.float16, 1]
        var float2: Float32 = 0.5  # SIMD[DType.float32, 1]
        result = Float32(float1) > float2  # Result is Bool
        print(result)

        True
Elementwise comparisons
For elementwise comparisons that return a SIMD[DType.bool] result, use the comparison methods: eq(), ne(), lt(), le(), gt(), and ge(). These methods work with both SIMD-to-SIMD and SIMD-to-scalar comparisons. Here are examples showing all six elementwise comparison methods:
    simd8 = SIMD[DType.int32, 4](1, 2, 3, 2)
    simd9 = SIMD[DType.int32, 4](1, 2, 4, 2)
    print("simd8.eq(simd9):", simd8.eq(simd9))  # Equal
    print("simd8.ne(simd9):", simd8.ne(simd9))  # Not equal
    print("simd8.lt(simd9):", simd8.lt(simd9))  # Less than
    print("simd8.le(simd9):", simd8.le(simd9))  # Less than or equal
    print("simd8.gt(simd9):", simd8.gt(simd9))  # Greater than
    print("simd8.ge(simd9):", simd8.ge(simd9))  # Greater than or equal

    simd8.eq(simd9): [True, True, False, True]
    simd8.ne(simd9): [False, False, True, False]
    simd8.lt(simd9): [False, False, True, False]
    simd8.le(simd9): [True, True, True, True]
    simd8.gt(simd9): [False, False, False, False]
    simd8.ge(simd9): [True, True, False, True]

You can also use these methods for SIMD-to-scalar comparisons:

    simd4 = SIMD[DType.int16, 4](-1, 2, -3, 4)
    simd5 = simd4.gt(2)  # SIMD[DType.bool, 4]
    print("simd4.gt(2):", simd5)

    simd6 = SIMD[DType.float32, 4](1.1, -2.2, 3.3, -4.4)
    simd7 = simd6.gt(0.5)  # SIMD[DType.bool, 4]
    print("simd6.gt(0.5):", simd7)

    simd4.gt(2): [False, False, False, True]
    simd6.gt(0.5): [True, False, True, False]

Use elementwise comparison methods when you need to compare each element individually and work with the resulting boolean mask for further processing.
String operators
As discussed in Strings, the String type represents a mutable string value. In contrast, the StringLiteral type represents a literal string that is embedded into your compiled program, but at run-time it materializes to a String, allowing you to mutate it:
    message = "Hello"  # type = String
    alias name = " Pat"  # type = StringLiteral
    greeting = " good Day!"  # type = String
    # Mutate the original `message` String
    message += name
    message += greeting
    print(message)

    Hello Pat good Day!

This means that StringLiteral values can be intermixed with String values in any run-time expression without having to convert between types.
String concatenation
The + operator performs string concatenation. The StringLiteral type supports compile-time string concatenation.
    alias last_name = "Curie"

    # Compile-time StringLiteral alias
    alias marie = "Marie " + last_name
    print(marie)

    # Compile-time concatenation before materializing to a run-time `String`
    pierre = "Pierre " + last_name
    print(pierre)

String replication
The * operator replicates a String a specified number of times. For example:
    var str1: String = "la"
    str2 = str1 * 5
    print(str2)

    lalalalala

StringLiteral supports the * operator for both compile-time and run-time string replication. The following examples perform compile-time string replication resulting in StringLiteral values:

    alias divider1 = "=" * 40
    alias symbol = "#"
    alias divider2 = symbol * 40

    # You must define the following function using `fn` because an alias
    # initializer cannot call a function that can potentially raise an error.
    fn generate_divider(char: String, repeat: Int) -> String:
        return char * repeat

    alias divider3 = generate_divider("~", 40)  # Evaluated at compile-time

    print(divider1)
    print(divider2)
    print(divider3)

    ========================================
    ########################################
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In contrast, the following examples perform run-time string replication resulting in String values:

    repeat = 40
    div1 = "^" * repeat
    print(div1)
    print("_" * repeat)

    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ________________________________________

String comparison
String and StringLiteral values can be compared using standard lexicographical ordering, producing a Bool. For example, "Zebra" is treated as less than "ant" because upper case letters occur before lower case letters in the character encoding.
    var animal: String = "bird"

    is_cat_eq = "cat" == animal
    print('Is "cat" equal to "{}"?'.format(animal), is_cat_eq)

    is_cat_ne = "cat" != animal
    print('Is "cat" not equal to "{}"?'.format(animal), is_cat_ne)

    is_bird_eq = "bird" == animal
    print('Is "bird" equal to "{}"?'.format(animal), is_bird_eq)

    is_cat_gt = "CAT" > animal
    print('Is "CAT" greater than "{}"?'.format(animal), is_cat_gt)

    is_ge_cat = animal >= "CAT"
    print('Is "{}" greater than or equal to "CAT"?'.format(animal), is_ge_cat)

    Is "cat" equal to "bird"? False
    Is "cat" not equal to "bird"? True
    Is "bird" equal to "bird"? True
    Is "CAT" greater than "bird"? False
    Is "bird" greater than or equal to "CAT"? True

Substring testing
String, StringLiteral, and StringSlice support using the in operator to produce a Bool result indicating whether a given substring appears within another string. The operator is overloaded so that you can use any combination of String and StringLiteral for both the substring and the string to test.
    var food: String = "peanut butter"

    if "nut" in food:
        print("It contains a nut")
    else:
        print("It doesn't contain a nut")

    It contains a nut

String indexing and slicing
String, StringLiteral, and StringSlice allow you to use indexing to return a single character. Character positions are identified with a zero-based index starting from the first character. You can also specify a negative index to count backwards from the end of the string, with the last character identified by index -1. Specifying an index beyond the bounds of the string results in a run-time error.
    var alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # String type value
    print(alphabet[0], alphabet[-1])

    # The following would produce a run-time error
    # print(alphabet[45])

    A Z

The String and StringSlice types (but not the StringLiteral type) also support slices to return a substring from the original String. Providing a slice in the form [start:end] returns a substring starting with the character index specified by start and continuing up to but not including the character at index end. You can use positive or negative indexing for both the start and end values. Omitting start is the same as specifying 0, and omitting end is the same as specifying the length of the string.

    var alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # String type value
    print(alphabet[1:4])  # The 2nd through 4th characters
    print(alphabet[:6])  # The first 6 characters
    print(alphabet[-6:])  # The last 6 characters

    BCD
    ABCDEF
    UVWXYZ

You can also specify a slice with a step value, as in [start:end:step], indicating the increment between subsequent indices of the slice. (This is also sometimes referred to as a "stride.") If you provide a negative value for step, characters are selected in reverse order starting with start but then with decreasing index values up to but not including end.

    print(alphabet[1:6:2])  # The 2nd, 4th, and 6th characters
    print(alphabet[-1:-4:-1])  # The last 3 characters in reverse order
    print(alphabet[::-1])  # The entire string reversed

    BDF
    ZYX
    ZYXWVUTSRQPONMLKJIHGFEDCBA

In-place assignment operators
Mutable types that support binary arithmetic, bitwise, and shift operators typically support equivalent in-place assignment operators. That means that for a type that supports the + operator, the following two statements are essentially equivalent:
    a = a + b
    a += b

However, there is a subtle difference between the two. In the first example, the expression a + b produces a new value, which is then assigned to a. In contrast, the second example does an in-place modification of the value currently assigned to a. For register-passable types, the compiled results might be equivalent at run-time. But for a memory-only type, the first example allocates storage for the result of a + b and then assigns the value to the variable, whereas the second example can do an in-place modification of the existing value.
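To make the difference visible, here is a minimal sketch that uses a hypothetical Counter struct (the name, field, and print statements are illustrative assumptions, not part of the standard library). Each dunder method reports when it is called, so you can see which one each statement invokes:

```mojo
@fieldwise_init
struct Counter(Copyable, Movable):
    var count: Int

    # Invoked for `c + 2`: builds and returns a new Counter.
    fn __add__(self, rhs: Int) -> Self:
        print("__add__ called")
        return Counter(self.count + rhs)

    # Invoked for `c += 3`: updates the existing Counter.
    fn __iadd__(mut self, rhs: Int):
        print("__iadd__ called")
        self.count += rhs


fn main():
    var c = Counter(1)
    c = c + 2  # Creates a new value, then assigns it to `c`
    c += 3  # Modifies `c` in place
    print(c.count)  # 1 + 2 + 3 = 6
```

Both statements leave c holding the same kind of result, but only the first one constructs an intermediate value.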
Assignment expressions
The "walrus" operator, :=, allows you to assign a value to a variable within an expression. The value provided is both assigned to the variable and becomes the result of the expression. This often can simplify conditional or looping logic. For example, consider the following prompting loop:
    while True:
        name = input("Enter a name or 'quit' to exit: ")
        if name == "quit":
            break
        print("Hello,", name)

    Enter a name or 'quit' to exit: Coco
    Hello, Coco
    Enter a name or 'quit' to exit: Vivienne
    Hello, Vivienne
    Enter a name or 'quit' to exit: quit

Using the walrus operator, you can implement the same behavior like this:

    while (name := input("Enter a name or 'quit' to exit: ")) != "quit":
        print("Hello,", name)

    Enter a name or 'quit' to exit: Donna
    Hello, Donna
    Enter a name or 'quit' to exit: Vera
    Hello, Vera
    Enter a name or 'quit' to exit: quit

Type merging
When an expression involves values of different types, Mojo needs to statically determine the return type of the expression. This process is called type merging. By default, Mojo determines type merging based on implicit conversions. Individual structs can also define custom type merging behavior.
The following code demonstrates type merging based on implicit conversions:
    list = [0.5, 1, 2]
    for value in list:
        print(value)

    0.5
    1.0
    2.0

Here, the list literal includes both float and integer literals, which materialize as Float64 and Int, respectively. Since Int can be implicitly converted to Float64, the result is a List[Float64].
Here's an example of where type merging fails:
    a: Int = 0
    b: String = "Hello"
    c = a if a > 0 else b  # Error: value of type 'Int' is not compatible with
                           # value of type 'String'

In this case, Int can't be implicitly converted to a String, and String can't be implicitly converted to an Int, so type merging fails. This is the correct result: there's no way for Mojo to know what type you want c to take. You can fix this by adding an explicit conversion:

    c = String(a) if a > 0 else b

Individual structs can define custom type merging logic by defining a __merge_with__() dunder method. For example:

    @fieldwise_init
    struct MyType(Movable, Copyable):
        var val: Int

        def __bool__(self) -> Bool:
            return self.val > 0

        def __merge_with__[other_type: type_of(Int)](self) -> Int:
            return Int(self.val)

    def main():
        i = 0
        m = MyType(9)
        print(i if i > 0 else m)  # prints "9"

If either type in the expression defines a custom __merge_with__() dunder for merging with the other type, that method takes precedence over any implicit conversions. (Note that the result type doesn't have to be either of the input types; it could be a third type.)
A type can declare multiple __merge_with__() overrides for different types.
At a high level, the logic for merging two types goes like this:
- Does either type define a __merge_with__() method for the other type? If so, the returned value determines the target type.
  - If both types define a __merge_with__() method for the other type, the two methods must both return the same type, or the conversion fails.
  - Both types must be implicitly convertible to the target type (a type is always implicitly convertible to itself).
- Is either type implicitly convertible to the other type?
  - If only one type is implicitly convertible to the other type, convert it.
  - If both types are convertible to the other type, the conversion is ambiguous, and it fails.
For more background on type merging and the __merge_with__() dunder, see the proposal, Customizable Type Merging in Mojo.
Implement operators for custom types
When you create a custom struct, Mojo allows you to define the behavior of many of the built-in operators for that type by implementing special dunder (double underscore) methods. This section lists the dunder methods associated with the operators and briefly describes the requirements for implementing them.
Unary operator dunder methods
A unary operator invokes an associated dunder method on the value to which it applies. The supported unary operators and their corresponding methods are shown in the table below.
| Operator | Dunder method |
|---|---|
| + positive | __pos__() |
| - negative | __neg__() |
| ~ bitwise NOT | __invert__() |
For each of these methods that you decide to implement, you should return either the original value if unchanged, or a new value representing the result of the operator. For example, you could implement the - negative operator for a MyInt struct like this:
    @fieldwise_init
    struct MyInt:
        var value: Int

        def __neg__(self) -> Self:
            return Self(-self.value)

Binary arithmetic, shift, and bitwise operator dunder methods
When you have a binary expression like a + b, there are two possible dunder methods that could be invoked.
Mojo first determines whether the left-hand side value (a in this example) has a "normal" version of the + operator's dunder method defined that accepts a value of the right-hand side's type. If so, it then invokes that method on the left-hand side value and passes the right-hand side value as an argument.
If Mojo doesn't find a matching "normal" dunder method on the left-hand side value, it then checks whether the right-hand side value has a "reflected" (sometimes referred to as "reversed") version of the + operator's dunder method defined that accepts a value of the left-hand side's type. If so, it then invokes that method on the right-hand side value and passes the left-hand side value as an argument.
For both the normal and the reflected versions, the dunder method should return a new value representing the result of the operator.
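As a rough illustration of this lookup order, the sketch below defines a hypothetical Score struct (the name and field are assumptions made for this example) with one normal and one reflected addition method, each taking an Int on the other side:

```mojo
@fieldwise_init
struct Score(Copyable, Movable):
    var points: Int

    # Normal version: handles `Score + Int`.
    fn __add__(self, rhs: Int) -> Self:
        return Score(self.points + rhs)

    # Reflected version: handles `Int + Score`.
    fn __radd__(self, lhs: Int) -> Self:
        return Score(lhs + self.points)


fn main():
    var score = Score(3)
    var bonus: Int = 10
    var a = score + bonus  # Score.__add__() matches the right-hand side type
    var b = bonus + score  # Int has no __add__() for Score, so Score.__radd__() is used
    print(a.points, b.points)  # 13 13
```

Each method returns a new Score rather than modifying either operand, in line with the guidance above.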
Additionally, there are dunder methods corresponding to the in-place assignment versions of the operators. These methods receive the right-hand side value as an argument and the methods should modify the existing left-hand side value to reflect the result of the operator.
The table below lists the various binary arithmetic, shift, and bitwise operators and their corresponding normal, reflected, and in-place dunder methods.
| Operator | Normal | Reflected | In-place |
|---|---|---|---|
| + addition | __add__() | __radd__() | __iadd__() |
| - subtraction | __sub__() | __rsub__() | __isub__() |
| * multiplication | __mul__() | __rmul__() | __imul__() |
| / division | __truediv__() | __rtruediv__() | __itruediv__() |
| // floor division | __floordiv__() | __rfloordiv__() | __ifloordiv__() |
| % modulus/remainder | __mod__() | __rmod__() | __imod__() |
| ** exponentiation | __pow__() | __rpow__() | __ipow__() |
| @ matrix multiplication | __matmul__() | __rmatmul__() | __imatmul__() |
| << left shift | __lshift__() | __rlshift__() | __ilshift__() |
| >> right shift | __rshift__() | __rrshift__() | __irshift__() |
| & bitwise AND | __and__() | __rand__() | __iand__() |
| \| bitwise OR | __or__() | __ror__() | __ior__() |
| ^ bitwise XOR | __xor__() | __rxor__() | __ixor__() |
As an example, consider implementing support for all of the + operator dunder methods for a custom MyInt struct. The example supports adding two MyInt instances as well as adding a MyInt and an Int. We can support the case of having the Int as the right-hand side argument by overloading the definition of __add__(). But to support the case of having the Int as the left-hand side argument, we need to implement an __radd__() method, because the built-in Int type doesn't have an __add__() method that supports our custom MyInt type.

    @fieldwise_init
    struct MyInt:
        var value: Int

        def __add__(self, rhs: MyInt) -> Self:
            return MyInt(self.value + rhs.value)

        def __add__(self, rhs: Int) -> Self:
            return MyInt(self.value + rhs)

        def __radd__(self, lhs: Int) -> Self:
            return MyInt(self.value + lhs)

        def __iadd__(mut self, rhs: MyInt) -> None:
            self.value += rhs.value

        def __iadd__(mut self, rhs: Int) -> None:
            self.value += rhs

Comparison operator dunder methods
When you have a comparison expression like a < b, Mojo invokes an associated dunder method on the left-hand side value and passes the right-hand side value as an argument. Mojo doesn't support "reflected" versions of these dunder methods because you should only compare values of the same type. The comparison dunder methods must return a Bool result representing the result of the comparison.
There are two traits associated with the comparison dunder methods. A type that implements the Comparable trait defines all of the comparison methods, and authors are required to implement at least the "less-than" and "equal" methods, since the trait provides defaults for the rest. However, some types don't have a natural ordering (for example, complex numbers). For those types you can decide to implement the Equatable trait, which defines only the "equal" and "not equal" comparison methods; conforming structs must implement the "equal" method.
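As a minimal sketch of the Equatable case, consider a hypothetical Point struct (the name and fields are illustrative assumptions). It implements only __eq__() and relies on the default "not equal" behavior that the trait is described as providing:

```mojo
@fieldwise_init
struct Point(Equatable):
    var x: Int
    var y: Int

    # The only comparison method this struct implements itself.
    fn __eq__(self, other: Self) -> Bool:
        return self.x == other.x and self.y == other.y


fn main():
    var p1 = Point(1, 2)
    var p2 = Point(1, 2)
    var p3 = Point(3, 4)
    print(p1 == p2)  # True: both fields match
    print(p1 != p3)  # True: uses the trait's default "not equal" method
```

Because Point has no natural ordering, it deliberately leaves out <, <=, >, and >=.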
The supported comparison operators and their corresponding methods are shown in the table below.
| Operator | Dunder method |
|---|---|
| == equal | __eq__() |
| != not equal | __ne__() |
| < less than | __lt__() |
| <= less than or equal | __le__() |
| > greater than | __gt__() |
| >= greater than or equal | __ge__() |
As an example, consider implementing support for all of the comparison operator dunder methods for a custom MyInt struct by relying on the default implementations provided by the Comparable (and transitively the Equatable) traits.
    @fieldwise_init
    struct MyInt(Comparable):
        var value: Int

        fn __eq__(self, rhs: MyInt) -> Bool:
            return self.value == rhs.value

        fn __lt__(self, rhs: MyInt) -> Bool:
            return self.value < rhs.value

        # `__ne__`, `__le__`, `__gt__`, and `__ge__` have default implementations.

Membership operator dunder methods
The in and not in operators depend on a type implementing the __contains__() dunder method. Typically only collection types (such as List, Dict, and Set) implement this method. It should accept the right-hand side value as an argument and return a Bool indicating whether the value is present in the collection or not.
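For illustration, the following sketch defines a hypothetical IntRange struct (the name, fields, and half-open range semantics are assumptions for this example only) that supports membership tests by implementing __contains__():

```mojo
@fieldwise_init
struct IntRange:
    var low: Int
    var high: Int

    # Returns True when `value` falls in the half-open range [low, high).
    fn __contains__(self, value: Int) -> Bool:
        return self.low <= value and value < self.high


fn main():
    var digits = IntRange(0, 10)
    print(5 in digits)  # True
    print(12 in digits)  # False
    print(12 not in digits)  # True
```

A single __contains__() method serves both operators, since not in is simply the negation of in.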
Subscript and slicing dunder methods
Subscripting and slicing typically apply only to sequential collection types, like List and String. Subscripting references a single element of a collection or a dimension of a multi-dimensional container, whereas slicing refers to a range of values. A type supports both subscripting and slicing by implementing the __getitem__() method for retrieving values and the __setitem__() method for setting values.
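As a small runnable sketch of subscripting, the hypothetical Vec3 struct below (its name, its three fields, and its lack of bounds checking are all simplifying assumptions) maps the indices 0, 1, and 2 to its fields:

```mojo
@fieldwise_init
struct Vec3(Copyable, Movable):
    var x: Float64
    var y: Float64
    var z: Float64

    # Invoked for reads such as `v[1]`.
    # For brevity, any index other than 0 or 1 refers to `z`.
    fn __getitem__(self, idx: Int) -> Float64:
        if idx == 0:
            return self.x
        if idx == 1:
            return self.y
        return self.z

    # Invoked for writes such as `v[1] = 5.0`.
    fn __setitem__(mut self, idx: Int, value: Float64):
        if idx == 0:
            self.x = value
        elif idx == 1:
            self.y = value
        else:
            self.z = value


fn main():
    var v = Vec3(1.0, 2.0, 3.0)
    v[1] = 5.0  # Calls __setitem__()
    print(v[0], v[1], v[2])  # Calls __getitem__() three times
```

A real collection would typically validate the index instead of silently falling through to the last element.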
Subscripting
In the simple case of a one-dimensional sequence, the __getitem__() and __setitem__() methods should have signatures similar to this:
    struct MySeq[type: Copyable & Movable]:
        fn __getitem__(self, idx: Int) -> type:
            # Return element at the given index
            ...

        fn __setitem__(mut self, idx: Int, value: type):
            # Assign the element at the given index the provided value
            ...

It's also possible to support multi-dimensional collections, in which case you can implement both __getitem__() and __setitem__() methods to accept multiple index arguments, or even variadic index arguments for arbitrary-dimension collections.

    struct MySeq[type: Copyable & Movable]:
        # 2-dimension support
        fn __getitem__(self, x_idx: Int, y_idx: Int) -> type:
            ...

        # Arbitrary-dimension support
        fn __getitem__(self, *indices: Int) -> type:
            ...

Slicing
You provide slicing support for a collection type also by implementing __getitem__() and __setitem__() methods. But for slicing, instead of accepting an Int index (or indices, in the case of a multi-dimensional collection) you implement to methods to accept a Slice (or multiple Slices in the case of a multi-dimensional collection).

    struct MySeq[type: Copyable & Movable]:
        # Return a new MySeq with a subset of elements
        fn __getitem__(self, span: Slice) -> Self:
            ...

A Slice contains three fields:
- start (Optional[Int]): The starting index of the slice.
- end (Optional[Int]): The ending index of the slice.
- step (Optional[Int]): The step increment value of the slice.
Because the start, end, and step values are all optional when using slice syntax, they are represented as Optional[Int] values in the Slice. If present, the index values might be negative, representing a relative position from the end of the sequence. As a convenience, Slice provides an indices() method that accepts a length value and returns a 3-tuple of "normalized" start, end, and step values for the given length, all represented as non-negative values. You can then use these normalized values to determine which elements of your collection are being referenced.

    struct MySeq[type: Copyable & Movable]:
        var size: Int

        # Return a new MySeq with a subset of elements
        fn __getitem__(self, span: Slice) -> Self:
            var start: Int
            var end: Int
            var step: Int
            start, end, step = span.indices(self.size)
            ...

An example of implementing operators for a custom type
As an example of implementing operators for a custom Mojo type, let's create a Complex struct to represent a single complex number, with both the real and imaginary components stored as Float64 values. We'll implement most of the arithmetic operators, the associated in-place assignment operators, the equality comparison operators, and a few additional convenience methods to support operations like printing complex values. We'll also allow mixing Complex and Float64 values in arithmetic expressions to produce a Complex result.
This example builds our Complex struct incrementally. You can also find the complete example in the public GitHub repo.
Implement lifecycle methods
Our Complex struct is an example of a simple value type consisting of trivial numeric fields and requiring no special constructor or destructor behaviors. This means we can use the @register_passable("trivial") decorator, which declares that the type can be trivially copied, moved, and destroyed—and doesn't need a copy constructor, move constructor, or destructor.
For the time being, we'll also use the @fieldwise_init decorator to automatically implement a field-wise initializer (a constructor with arguments for each field).

    @fieldwise_init
    @register_passable("trivial")
    struct Complex:
        var re: Float64
        var im: Float64

This definition is enough for us to create Complex instances and access their real and imaginary fields.

    c1 = Complex(-1.2, 6.5)
    print("c1: Real: {}; Imaginary: {}".format(c1.re, c1.im))

Output:

    c1: Real: -1.2; Imaginary: 6.5

As a convenience, let's add an explicit constructor to handle the case of creating a Complex instance with an imaginary component of 0.

    @register_passable("trivial")
    struct Complex():
        var re: Float64
        var im: Float64

        fn __init__(out self, re: Float64, im: Float64 = 0.0):
            self.re = re
            self.im = im

Since this constructor also handles creating a Complex instance with both real and imaginary components, we don't need the @fieldwise_init decorator anymore.
Now we can create a Complex instance and provide just a real component.

    c2 = Complex(3.14159)
    print("c2: Real: {}; Imaginary: {}".format(c2.re, c2.im))

Output:

    c2: Real: 3.1415899999999999; Imaginary: 0.0

Implement the Writable and Stringable traits
To make it simpler to print Complex values, let's implement the Writable trait. While we're at it, let's also implement the Stringable trait so that we can use the String() constructor to generate a String representation of a Complex value. You can find out more about these traits and their associated methods in The Stringable, Representable, and Writable traits.

    @register_passable("trivial")
    struct Complex(
        Writable,
        Stringable,
    ):
        # ...

        fn __str__(self) -> String:
            return String.write(self)

        fn write_to(self, mut writer: Some[Writer]):
            writer.write("(", self.re)
            if self.im < 0:
                writer.write(" - ", -self.im)
            else:
                writer.write(" + ", self.im)
            writer.write("i)")

Now we can print a Complex value directly. We can also explicitly generate a String representation by passing a Complex value to String(), which constructs a new String from all the arguments passed to it.

    c3 = Complex(3.14159, -2.71828)
    print("c3 =", c3)
    var msg = String("The value is: ", c3)
    print(msg)

Output:

    c3 = (3.1415899999999999 - 2.71828i)
    The value is: (3.1415899999999999 - 2.71828i)

Implement basic indexing
Indexing usually is supported only by collection types. But as an example, let's implement support for accessing the real component as index 0 and the imaginary component as index 1. We'll not implement slicing or variadic assignment for this example.

        # ...

        def __getitem__(self, idx: Int) -> Float64:
            if idx == 0:
                return self.re
            elif idx == 1:
                return self.im
            else:
                raise "index out of bounds"

        def __setitem__(mut self, idx: Int, value: Float64) -> None:
            if idx == 0:
                self.re = value
            elif idx == 1:
                self.im = value
            else:
                raise "index out of bounds"

Now let's try getting and setting the real and imaginary components of a Complex value using indexing.

    c2 = Complex(3.14159)
    print("c2[0]: {}; c2[1]: {}".format(c2[0], c2[1]))
    c2[0] = 2.71828
    c2[1] = 42
    print("c2[0] = 2.71828; c2[1] = 42; c2:", c2)

Output:

    c2[0]: 3.1415899999999999; c2[1]: 0.0
    c2[0] = 2.71828; c2[1] = 42; c2: (2.71828 + 42.0i)

Implement arithmetic operators
Now let's implement the dunder methods that allow us to perform arithmetic operations on Complex values. (Refer to the Wikipedia page on complex numbers for a more in-depth explanation of the formulas for these operators.)
Implement basic operators for Complex values
The unary + operator simply returns the original value, whereas the unary - operator returns a new Complex value with the real and imaginary components negated.

        # ...

        def __pos__(self) -> Self:
            return self

        def __neg__(self) -> Self:
            return Self(-self.re, -self.im)

Let's test these out by printing the result of applying each operator.

    c1 = Complex(-1.2, 6.5)
    print("+c1:", +c1)
    print("-c1:", -c1)

Output:

    +c1: (-1.2 + 6.5i)
    -c1: (1.2 - 6.5i)

Next we'll implement the basic binary operators: +, -, *, and /. Dividing complex numbers is a bit tricky, so we'll also define a helper method called norm() to calculate the Euclidean norm of a Complex instance, which can also be useful for other types of analysis with complex numbers.
For all of these dunder methods, the left-hand side operand is self and the right-hand side operand is passed as an argument. We return a new Complex value representing the result.
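For reference, these are the standard identities that the four methods rely on, writing the left-hand operand as $a + bi$ and the right-hand operand as $c + di$ (the letters are notation introduced here, not names used in the code):

```latex
\begin{align*}
(a + bi) + (c + di) &= (a + c) + (b + d)\,i \\
(a + bi) - (c + di) &= (a - c) + (b - d)\,i \\
(a + bi)(c + di)    &= (ac - bd) + (ad + bc)\,i \\
\frac{a + bi}{c + di}
  &= \frac{(a + bi)(c - di)}{(c + di)(c - di)}
   = \frac{(ac + bd) + (bc - ad)\,i}{c^{2} + d^{2}}
\end{align*}
```

The division identity comes from multiplying the numerator and denominator by the conjugate $c - di$. The resulting denominator $c^{2} + d^{2}$ is the squared norm of the right-hand operand, which is why a helper that computes it is convenient.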

    from math import sqrt

        # ...

        def __add__(self, rhs: Self) -> Self:
            return Self(self.re + rhs.re, self.im + rhs.im)

        def __sub__(self, rhs: Self) -> Self:
            return Self(self.re - rhs.re, self.im - rhs.im)

        def __mul__(self, rhs: Self) -> Self:
            return Self(
                self.re * rhs.re - self.im * rhs.im,
                self.re * rhs.im + self.im * rhs.re
            )

        def __truediv__(self, rhs: Self) -> Self:
            denom = rhs.squared_norm()
            return Self(
                (self.re * rhs.re + self.im * rhs.im) / denom,
                (self.im * rhs.re - self.re * rhs.im) / denom
            )

        def squared_norm(self) -> Float64:
            return self.re * self.re + self.im * self.im

        def norm(self) -> Float64:
            return sqrt(self.squared_norm())

Now we can try them out.

    c1 = Complex(-1.2, 6.5)
    c3 = Complex(3.14159, -2.71828)
    print("c1 + c3 =", c1 + c3)
    print("c1 - c3 =", c1 - c3)
    print("c1 * c3 =", c1 * c3)
    print("c1 / c3 =", c1 / c3)

Output:

    c1 + c3 = (1.9415899999999999 + 3.78172i)
    c1 - c3 = (-4.3415900000000001 + 9.21828i)
    c1 * c3 = (13.898912000000001 + 23.682270999999997i)
    c1 / c3 = (-1.2422030701265261 + 0.99419218883955773i)

Implement overloaded arithmetic operators for Float64 values
Our initial set of binary arithmetic operators works fine if both operands are Complex instances. But if we have a Float64 value representing just a real value, we'd first need to use it to create a Complex value before we could add, subtract, multiply, or divide it with another Complex value. If we think that this will be a common use case, it makes sense to overload our arithmetic methods to accept a Float64 as the second operand.
For the case where we have complex1 + float1, we can just create an overloaded definition of __add__(). But what about the case of float1 + complex1? By default, when Mojo encounters a + operator it tries to invoke the __add__() method of the left-hand operand, but the built-in Float64 type doesn't implement support for addition with a Complex value. This is an example where we need to implement the __radd__() method on the Complex type. When Mojo can't find an __add__(self, rhs: Complex) -> Complex method defined on Float64, it uses the __radd__(self, lhs: Float64) -> Complex method defined on Complex.
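Addition and multiplication by a real number are commutative, so their reflected methods can compute the same result as the forward versions. Subtraction and division are not commutative, so the reflected methods need their own formulas. As a sketch, writing the Float64 operand as $x$ and the Complex operand as $a + bi$ (notation introduced here for illustration):

```latex
\begin{align*}
x - (a + bi) &= (x - a) - b\,i \\
\frac{x}{a + bi}
  &= \frac{x\,(a - bi)}{(a + bi)(a - bi)}
   = \frac{x a}{a^{2} + b^{2}} - \frac{x b}{a^{2} + b^{2}}\,i
\end{align*}
```

In both cases the imaginary component changes sign relative to the Complex operand. This is easy to overlook when adapting the forward methods.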
So we can support arithmetic operations on Complex and Float64 values by implementing the following eight methods.

        # ...

        def __add__(self, rhs: Float64) -> Self:
            return Self(self.re + rhs, self.im)

        def __radd__(self, lhs: Float64) -> Self:
            return Self(self.re + lhs, self.im)

        def __sub__(self, rhs: Float64) -> Self:
            return Self(self.re - rhs, self.im)

        def __rsub__(self, lhs: Float64) -> Self:
            return Self(lhs - self.re, -self.im)

        def __mul__(self, rhs: Float64) -> Self:
            return Self(self.re * rhs, self.im * rhs)

        def __rmul__(self, lhs: Float64) -> Self:
            return Self(lhs * self.re, lhs * self.im)

        def __truediv__(self, rhs: Float64) -> Self:
            return Self(self.re / rhs, self.im / rhs)

        def __rtruediv__(self, lhs: Float64) -> Self:
            denom = self.squared_norm()
            return Self(
                (lhs * self.re) / denom,
                (-lhs * self.im) / denom
            )

Let's see them in action.

    c1 = Complex(-1.2, 6.5)
    f1 = 2.5
    print("c1 + f1 =", c1 + f1)
    print("f1 + c1 =", f1 + c1)
    print("c1 - f1 =", c1 - f1)
    print("f1 - c1 =", f1 - c1)
    print("c1 * f1 =", c1 * f1)
    print("f1 * c1 =", f1 * c1)
    print("c1 / f1 =", c1 / f1)
    print("f1 / c1 =", f1 / c1)

Output:

    c1 + f1 = (1.3 + 6.5i)
    f1 + c1 = (1.3 + 6.5i)
    c1 - f1 = (-3.7000000000000002 + 6.5i)
    f1 - c1 = (3.7000000000000002 - 6.5i)
    c1 * f1 = (-3.0 + 16.25i)
    f1 * c1 = (-3.0 + 16.25i)
    c1 / f1 = (-0.47999999999999998 + 2.6000000000000001i)
    f1 / c1 = (-0.068665598535133904 - 0.37193865873197529i)

Implement in-place assignment operators
Now let's implement support for the in-place assignment operators: +=, -=, *=, and /=. These modify the original value, so we need to mark self as a mut argument and update the re and im fields instead of returning a new Complex instance. And once again, we'll overload the definitions to support both a Complex and a Float64 operand.

        # ...

        def __iadd__(mut self, rhs: Self) -> None:
            self.re += rhs.re
            self.im += rhs.im

        def __iadd__(mut self, rhs: Float64) -> None:
            self.re += rhs

        def __isub__(mut self, rhs: Self) -> None:
            self.re -= rhs.re
            self.im -= rhs.im

        def __isub__(mut self, rhs: Float64) -> None:
            self.re -= rhs

        def __imul__(mut self, rhs: Self) -> None:
            new_re = self.re * rhs.re - self.im * rhs.im
            new_im = self.re * rhs.im + self.im * rhs.re
            self.re = new_re
            self.im = new_im

        def __imul__(mut self, rhs: Float64) -> None:
            self.re *= rhs
            self.im *= rhs

        def __itruediv__(mut self, rhs: Self) -> None:
            denom = rhs.squared_norm()
            new_re = (self.re * rhs.re + self.im * rhs.im) / denom
            new_im = (self.im * rhs.re - self.re * rhs.im) / denom
            self.re = new_re
            self.im = new_im

        def __itruediv__(mut self, rhs: Float64) -> None:
            self.re /= rhs
            self.im /= rhs

And now to try them out.

    c4 = Complex(-1, -1)
    print("c4 =", c4)
    c4 += Complex(0.5, -0.5)
    print("c4 += Complex(0.5, -0.5) =>", c4)
    c4 += 2.75
    print("c4 += 2.75 =>", c4)
    c4 -= Complex(0.25, 1.5)
    print("c4 -= Complex(0.25, 1.5) =>", c4)
    c4 -= 3
    print("c4 -= 3 =>", c4)
    c4 *= Complex(-3.0, 2.0)
    print("c4 *= Complex(-3.0, 2.0) =>", c4)
    c4 *= 0.75
    print("c4 *= 0.75 =>", c4)
    c4 /= Complex(1.25, 2.0)
    print("c4 /= Complex(1.25, 2.0) =>", c4)
    c4 /= 2.0
    print("c4 /= 2.0 =>", c4)

Output:

    c4 = (-1.0 - 1.0i)
    c4 += Complex(0.5, -0.5) => (-0.5 - 1.5i)
    c4 += 2.75 => (2.25 - 1.5i)
    c4 -= Complex(0.25, 1.5) => (2.0 - 3.0i)
    c4 -= 3 => (-1.0 - 3.0i)
    c4 *= Complex(-3.0, 2.0) => (9.0 + 7.0i)
    c4 *= 0.75 => (6.75 + 5.25i)
    c4 /= Complex(1.25, 2.0) => (3.404494382022472 - 1.247191011235955i)
    c4 /= 2.0 => (1.702247191011236 - 0.6235955056179775i)

Implement equality operators
The field of complex numbers is not an ordered field, so it doesn't make sense for us to implement the Comparable trait and the >, >=, <, and <= operators. However, we can implement the Equatable trait and the == and != operators. (Of course, this suffers the same limitation of comparing floating point numbers for equality because of the limited precision of representing floating point numbers when performing arithmetic operations. But we'll go ahead and implement the operators for completeness.)
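To make the floating point caveat concrete, here is a minimal sketch that uses plain Float64 values rather than Complex, since the same issue applies to each component. The variable names and the tolerance of 1e-9 are arbitrary choices made for illustration, not part of the Complex example.

```mojo
def main():
    var a: Float64 = 0.1
    var b: Float64 = 0.2
    # Mathematically 0.3, but the sum is rounded in binary floating point.
    var total = a + b

    # Exact comparison, which is what an `==` on the fields performs.
    print("exact match:", total == 0.3)

    # Tolerance-based comparison; 1e-9 is an assumed threshold, chosen
    # only for this illustration.
    var tolerance: Float64 = 1e-9
    print("within tolerance:", abs(total - 0.3) <= tolerance)
```

With IEEE 754 double precision, the exact comparison is expected to report false even though the two values differ only in the last bits, while the tolerance-based check reports true. Whether a tolerance is appropriate, and how large it should be, depends on the application.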

    struct Complex(
        Equatable,
        Writable,
        Stringable,
    ):
        # ...

        fn __eq__(self, other: Self) -> Bool:
            return self.re == other.re and self.im == other.im

        fn __ne__(self, other: Self) -> Bool:
            return self.re != other.re or self.im != other.im

And now to try them out.

    c1 = Complex(-1.2, 6.5)
    c3 = Complex(3.14159, -2.71828)
    c5 = Complex(-1.2, 6.5)

    if c1 == c5:
        print("c1 is equal to c5")
    else:
        print("c1 is not equal to c5")

    if c1 != c3:
        print("c1 is not equal to c3")
    else:
        print("c1 is equal to c3")

Output:

    c1 is equal to c5
    c1 is not equal to c3