diff --git a/docs/lexicon.html b/docs/lexicon.html index 351d565172..3090834c8b 100644 --- a/docs/lexicon.html +++ b/docs/lexicon.html @@ -158,9 +158,12 @@

sorted(seq)seq[,key]) - - returns a copy of the given list with the contents sorted. + - returns a copy of seq with the contents + sorted. key is a function that is applied + to each item before comparison.
  • diff --git a/rules/builtins.build_defs b/rules/builtins.build_defs index 7380ef8a6d..5488f9f105 100644 --- a/rules/builtins.build_defs +++ b/rules/builtins.build_defs @@ -125,7 +125,7 @@ def glob(include:list|str, exclude:list|str&excludes=[], hidden:bool=CONFIG.BAZE def package(): pass -def sorted(seq:list) -> list: +def sorted(seq:list, key:function=None) -> list: pass def reversed(seq:list) -> list: diff --git a/src/parse/asp/builtins.go b/src/parse/asp/builtins.go index 7562efd602..8ce3e24680 100644 --- a/src/parse/asp/builtins.go +++ b/src/parse/asp/builtins.go @@ -798,10 +798,30 @@ func dictCopy(s *scope, args []pyObject) pyObject { } func sorted(s *scope, args []pyObject) pyObject { - l, ok := args[0].(pyList) - s.Assert(ok, "unsortable type %s", args[0].Type()) + l, isList := args[0].(pyList) + key, isFunc := args[1].(*pyFunc) + s.Assert(isList, "Argument seq must be a list, not %s", args[0].Type()) l = l[:] - sort.Slice(l, func(i, j int) bool { return s.operator(LessThan, l[i], l[j]).IsTruthy() }) + if key == nil { + sort.Slice(l, func(i, j int) bool { + return s.operator(LessThan, l[i], l[j]).IsTruthy() + }) + } else { + s.Assert(isFunc, "Argument key must be callable, not %s", args[1].Type()) + sort.Slice(l, func(i, j int) bool { + iKey := key.Call(s, &Call{ + Arguments: []CallArgument{{ + Value: Expression{optimised: &optimisedExpression{Constant: l[i]}}, + }}, + }) + jKey := key.Call(s, &Call{ + Arguments: []CallArgument{{ + Value: Expression{optimised: &optimisedExpression{Constant: l[j]}}, + }}, + }) + return s.operator(LessThan, iKey, jKey).IsTruthy() + }) + } return l } diff --git a/src/parse/asp/interpreter_test.go b/src/parse/asp/interpreter_test.go index 578e9b3039..6a8faa3c11 100644 --- a/src/parse/asp/interpreter_test.go +++ b/src/parse/asp/interpreter_test.go @@ -169,8 +169,9 @@ func TestInterpreterSlicing(t *testing.T) { func TestInterpreterSorting(t *testing.T) { s, err := parseFile("src/parse/asp/test_data/interpreter/sorted.build") require.NoError(t, err) - assert.Equal(t, pyList{pyInt(1), pyInt(2), pyInt(3)}, s.Lookup("y")) // N.B. sorted() sorts in-place, unlike Python's one. We may change that later. + assert.Equal(t, pyList{pyInt(1), pyInt(2), pyInt(3)}, s.Lookup("r1")) + assert.Equal(t, pyList{pyString("ONE"), pyString("THREE"), pyString("two")}, s.Lookup("r2")) } func TestReversed(t *testing.T) { diff --git a/src/parse/asp/test_data/interpreter/sorted.build b/src/parse/asp/test_data/interpreter/sorted.build index a6ac1fde4c..29e2d58f33 100644 --- a/src/parse/asp/test_data/interpreter/sorted.build +++ b/src/parse/asp/test_data/interpreter/sorted.build @@ -1,2 +1,6 @@ -x = [3, 2, 1] -y = sorted(x) +l1 = [3, 2, 1] +r1 = sorted(l1) + +# key parameter test +l2 = ["ONE", "two", "THREE"] +r2 = sorted(l2, key=lambda s: s.lower())