diff --git a/docs/lexicon.html b/docs/lexicon.html
index 351d565172..3090834c8b 100644
--- a/docs/lexicon.html
+++ b/docs/lexicon.html
@@ -158,9 +158,12 @@
sorted(seq)seq[,key])
- - returns a copy of the given list with the contents sorted.
+ - returns a copy of seq with the contents
+ sorted. key is a function that is applied
+ to each item before comparison.
diff --git a/rules/builtins.build_defs b/rules/builtins.build_defs
index 7380ef8a6d..5488f9f105 100644
--- a/rules/builtins.build_defs
+++ b/rules/builtins.build_defs
@@ -125,7 +125,7 @@ def glob(include:list|str, exclude:list|str&excludes=[], hidden:bool=CONFIG.BAZE
def package():
pass
-def sorted(seq:list) -> list:
+def sorted(seq:list, key:function=None) -> list:
pass
def reversed(seq:list) -> list:
diff --git a/src/parse/asp/builtins.go b/src/parse/asp/builtins.go
index 7562efd602..8ce3e24680 100644
--- a/src/parse/asp/builtins.go
+++ b/src/parse/asp/builtins.go
@@ -798,10 +798,30 @@ func dictCopy(s *scope, args []pyObject) pyObject {
}
func sorted(s *scope, args []pyObject) pyObject {
- l, ok := args[0].(pyList)
- s.Assert(ok, "unsortable type %s", args[0].Type())
+ l, isList := args[0].(pyList)
+ key, isFunc := args[1].(*pyFunc)
+ s.Assert(isList, "Argument seq must be a list, not %s", args[0].Type())
l = l[:]
- sort.Slice(l, func(i, j int) bool { return s.operator(LessThan, l[i], l[j]).IsTruthy() })
+ if key == nil {
+ sort.Slice(l, func(i, j int) bool {
+ return s.operator(LessThan, l[i], l[j]).IsTruthy()
+ })
+ } else {
+ s.Assert(isFunc, "Argument key must be callable, not %s", args[1].Type())
+ sort.Slice(l, func(i, j int) bool {
+ iKey := key.Call(s, &Call{
+ Arguments: []CallArgument{{
+ Value: Expression{optimised: &optimisedExpression{Constant: l[i]}},
+ }},
+ })
+ jKey := key.Call(s, &Call{
+ Arguments: []CallArgument{{
+ Value: Expression{optimised: &optimisedExpression{Constant: l[j]}},
+ }},
+ })
+ return s.operator(LessThan, iKey, jKey).IsTruthy()
+ })
+ }
return l
}
diff --git a/src/parse/asp/interpreter_test.go b/src/parse/asp/interpreter_test.go
index 578e9b3039..6a8faa3c11 100644
--- a/src/parse/asp/interpreter_test.go
+++ b/src/parse/asp/interpreter_test.go
@@ -169,8 +169,9 @@ func TestInterpreterSlicing(t *testing.T) {
func TestInterpreterSorting(t *testing.T) {
s, err := parseFile("src/parse/asp/test_data/interpreter/sorted.build")
require.NoError(t, err)
- assert.Equal(t, pyList{pyInt(1), pyInt(2), pyInt(3)}, s.Lookup("y"))
// N.B. sorted() sorts in-place, unlike Python's one. We may change that later.
+ assert.Equal(t, pyList{pyInt(1), pyInt(2), pyInt(3)}, s.Lookup("r1"))
+ assert.Equal(t, pyList{pyString("ONE"), pyString("THREE"), pyString("two")}, s.Lookup("r2"))
}
func TestReversed(t *testing.T) {
diff --git a/src/parse/asp/test_data/interpreter/sorted.build b/src/parse/asp/test_data/interpreter/sorted.build
index a6ac1fde4c..29e2d58f33 100644
--- a/src/parse/asp/test_data/interpreter/sorted.build
+++ b/src/parse/asp/test_data/interpreter/sorted.build
@@ -1,2 +1,6 @@
-x = [3, 2, 1]
-y = sorted(x)
+l1 = [3, 2, 1]
+r1 = sorted(l1)
+
+# key parameter test
+l2 = ["ONE", "two", "THREE"]
+r2 = sorted(l2, key=lambda s: s.lower())