Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- '1.6'
- '1.9'
- '1.10'
- '1.12'
- 'nightly'
os:
- ubuntu-latest
Expand Down
9 changes: 5 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/Manifest.toml
/docs/Manifest.toml
/Manifest*.toml
/docs/Manifest*.toml
/docs/build/
*.code-workspace
Manifest.toml
LocalPreferences.toml
Manifest*.toml
LocalPreferences.toml
.vscode/
36 changes: 20 additions & 16 deletions src/define_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@ _propertynames(T::Type) = Base.fieldnames(T)
_propertynames(x) = Base.propertynames(x)

@generated function validate_properties(::Type{S}, delegated_fieldnames::Type{T}) where {S, T <: Tuple}
output = Expr(:block, :(unique_properties = Set{Symbol}($Base.fieldnames($S))))
output_args = Any[:(local unique_properties = Set{Symbol}($Base.fieldnames($S)))]
recursive = fieldtype(T, 1) === Val{:recursive}
for i in 2:fieldcount(T)
key = fieldtype(T, i)
props = Symbol(key, :_properties)
child_T = fieldtype(S, key)
if recursive
push!(output.args, :($props = Set{Symbol}( $_propertynames($child_T))))
push!(output_args, :(local $props = Set{Symbol}( $_propertynames($child_T))))
else
push!(output.args, :($props = Set{Symbol}( $fieldnames($child_T))))
push!(output_args, :(local $props = Set{Symbol}( $fieldnames($child_T))))
end

push!( output.args, :( diff = $intersect(unique_properties, $(props)) ), :(!isempty(diff) && error("Duplicate properties `$(sort(collect(diff)))` found for type $($(S)) in child `$($(QuoteNode(key)))::$($(child_T))` ")), :($union!(unique_properties, $props)))
push!( output_args, :( diff = $intersect(unique_properties, $(props)) ), :(!isempty(diff) && error("Duplicate properties `$(sort(collect(diff)))` found for type $($(S)) in child `$($(QuoteNode(key)))::$($(child_T))` ")), :($union!(unique_properties, $props)))
end
push!(output.args, :(return nothing))
return output
push!(output_args, :(return nothing))
return Expr(:block, output_args...)
end

@generated function _propertynames(::Type{S}, delegated_fields::Type{T}) where {S, T <: Tuple}
Base.isstructtype(S) || error("$S is not a struct type")
output = Expr(:tuple, QuoteNode.(fieldnames(S))...)
output_args = Any[QuoteNode.(fieldnames(S))...]
recursive = fieldtype(T, 1) === Val{:recursive}
for i in 2:fieldcount(T)
key = fieldtype(T,i)
Si = fieldtype(S, key)
push!(output.args, recursive ? :($_propertynames($(Si))...) : :($fieldnames($(Si))...))
push!(output_args, recursive ? :($_propertynames($(Si))...) : :($fieldnames($(Si))...))
end
return output
return Expr(:tuple, output_args...)
end
_propertynames(x, delegated_fields) = _propertynames(typeof(x), delegated_fields)

Expand Down Expand Up @@ -88,6 +88,7 @@ If `ensure_unique == true`, throws an error when there are nonunique names in th

"""
function properties_interface(T; delegated_fields, recursive::Bool=false, ensure_unique::Bool=true, kwargs...)
@nospecialize
if haskey(kwargs, :is_mutable)
Base.depwarn("Passing `is_mutable` kwarg when `interface=properties` is now deprecated", Symbol("@define_interface"))
is_mutable = get_kwarg(Bool, kwargs, :is_mutable, false)
Expand All @@ -110,14 +111,15 @@ function properties_interface(T; delegated_fields, recursive::Bool=false, ensure
_setproperty = :($ForwardMethods._setproperty!($obj::$T, $name::Symbol, $value) = $ForwardMethods._setproperty!($obj, $delegated_fields_tuple_type, $name, $value))
setproperty = :($Base.setproperty!($obj::$T, $name::Symbol, $value) = $ForwardMethods._setproperty!($obj, $name, $value))

output = Expr(:block, line_num)
output_args = Any[]
if ensure_unique
push!(output.args, :($validate_properties($T, $delegated_fields_tuple_type)))
push!(output_args, :($validate_properties($T, $delegated_fields_tuple_type)))
end
push!(output.args, map(linenum!, (_propertynames, _propertynamesT, propertynames, _getproperty, getproperty))...)
push!(output_args, map(linenum!, (_propertynames, _propertynamesT, propertynames, _getproperty, getproperty))...)
if is_mutable
push!(output.args, map(linenum!, (_setproperty, setproperty))...)
push!(output_args, map(linenum!, (_setproperty, setproperty))...)
end
output = Expr(:block, line_num, output_args...)
return wrap_define_interface(T, :properties, output)
end

Expand All @@ -138,6 +140,7 @@ Any values provided in `omit` are excluded from the generator expression above.

"""
function equality_interface(T; omit::AbstractVector{Symbol}=Symbol[], equality_op::Symbol=:(==), compare_fields::Symbol=:fieldnames)
@nospecialize
equality_op in (:(==), :isequal) || error("equality_op (= $equality_op) must be one of (==, isequal)")
if compare_fields == :fieldnames
getvalue = :($Base.getfield)
Expand Down Expand Up @@ -166,22 +169,23 @@ end
@method_def_constant define_interface_method(::Val{::Symbol}) define_interfaces_available

function define_interface_expr(T, kwargs::Dict{Symbol,Any}=Dict{Symbol,Any}(); _sourceinfo)
@nospecialize
interfaces = interface_kwarg!(kwargs)
omit = omit_kwarg!(kwargs)
output = Expr(:block)
output_args = Any[]
interfaces_available = define_interfaces_available()
for interface in interfaces
interface in interfaces_available || error("No interface found with name $interface -- must be one of `$interfaces_available`")
f = define_interface_method(Val(interface))

current_line_num[] = _sourceinfo
try
push!(output.args, f(T; omit, kwargs...))
push!(output_args, f(T; omit, kwargs...))
finally
current_line_num[] = nothing
end
end
return output
return Expr(:block, output_args...)
end

"""
Expand Down
25 changes: 14 additions & 11 deletions src/forward_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ interface_at_macroexpand_time(x) = true
base_forward_expr(f, args...) = Expr(:call, f, args...)

function forward_interface_args(T)
@nospecialize
t = gensym(arg_placeholder)
obj_arg = object_argument(t, T)
type_arg = type_argument(T)
Expand Down Expand Up @@ -102,6 +103,7 @@ according to the value of `index_style_linear`
Any function names specified in `omit::AbstractVector{Symbol}` will not be defined
"""
function array_interface(T; index_style_linear::Bool, omit::AbstractVector{Symbol}=Symbol[])
@nospecialize
obj_arg, type_arg, call_expr = forward_interface_args(T)

method_signatures = Any[
Expand Down Expand Up @@ -242,13 +244,14 @@ function getfields_interface(T; field::Union{Nothing,Symbol}=nothing, omit::Abst
return wrap_define_interface(T, :getfields, Base.remove_linenums!(quote
local omit_fields = $(Expr(:tuple, QuoteNode.(omit)...))
local fields = fieldnames($T)
local def_fields_expr = Expr(:block)
local var = gensym("x")
local def_fields_expr_args = Any[]
for field in fields
if field ∉ omit_fields
push!(def_fields_expr.args, :($field($var::$$T) = Base.getfield($var, $(QuoteNode(field)))))
push!(def_fields_expr_args, :($field($var::$$T) = Base.getfield($var, $(QuoteNode(field)))))
end
end
local def_fields_expr = Expr(:block, def_fields_expr_args...)
eval(def_fields_expr)
nothing
end))
Expand All @@ -265,13 +268,14 @@ function setfields_interface(T; field::Union{Nothing,Symbol}=nothing, omit::Abst
return wrap_define_interface(T, :setfields, Base.remove_linenums!(quote
local omit_fields = $(Expr(:tuple, QuoteNode.(omit)...))
local fields = fieldnames($T)
local def_fields_expr = Expr(:block)
local var = gensym("x")
local def_fields_expr_args = Any[]
for field in fields
if field ∉ omit_fields
push!(def_fields_expr.args, :($(Symbol(string(field)*"!"))($var::$$T, value) = Base.setfield!($var, $(QuoteNode(field)), value)))
push!(def_fields_expr_args, :($(Symbol(string(field)*"!"))($var::$$T, value) = Base.setfield!($var, $(QuoteNode(field)), value)))
end
end
local def_fields_expr = Expr(:block, def_fields_expr_args...)
eval(def_fields_expr)
nothing
end))
Expand Down Expand Up @@ -309,8 +313,7 @@ function forward_interface_expr(T, kwargs::Dict{Symbol,Any}=Dict{Symbol,Any}();
field_funcs = nothing
end

_output = Expr(:block)

_output_args = Any[]
available_interfaces = forward_interfaces_available()

for interface in interfaces
Expand All @@ -325,16 +328,16 @@ function forward_interface_expr(T, kwargs::Dict{Symbol,Any}=Dict{Symbol,Any}();
isnothing(field_funcs.type_func) && error("Only fieldname mode for `field` (= $field) supported for interface (= $interface_value)")
end
signatures = f(T; omit, kwargs...)
output = Expr(:block)
output_args = Any[]
for signature in signatures
push!(output.args, forward_method_signature(T, field_funcs, map_func, signature; _sourceinfo))
push!(output_args, forward_method_signature(T, field_funcs, map_func, signature; _sourceinfo))
end
push!(_output.args, output)
push!(_output_args, Expr(:block, output_args...))
else
push!(_output.args, f(T; omit, kwargs...))
push!(_output_args, f(T; omit, kwargs...))
end
end
return _output
return Expr(:block, _output_args...)
end

"""
Expand Down
29 changes: 16 additions & 13 deletions src/forward_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,27 @@ function forward_method_signature(Type, field_funcs::FieldFuncExprs, map_func::F
if !found
error("No argument matching type $Type in input = `$input`")
end
func_body = Expr(:call, funcname)
func_body_args = Any[funcname]
if !isempty(kwargs)
push!(func_body.args, Expr(:parameters, kwargs...))
push!(func_body_args, Expr(:parameters, kwargs...))
end
push!(func_body.args, output_args...)
body_block = Expr(:block)
push!(func_body_args, output_args...)
func_body = Expr(:call, func_body_args...)
body_block_args = Any[]
if !isnothing(_sourceinfo)
push!(body_block.args, _sourceinfo)
push!(body_block_args, _sourceinfo)
end
push!(body_block.args, found_arg_expr)
push!(body_block_args, found_arg_expr)
mapped_body = !found_arg_is_type ? map_func(found_input_arg, func_body) : func_body
push!(body_block.args, mapped_body)
push!(body_block_args, mapped_body)
body_block = Expr(:block, body_block_args...)

new_sig = Expr(:call, funcname)
new_sig_args = Any[funcname]
if !isempty(kwargs)
push!(new_sig.args, Expr(:parameters, kwargs...))
push!(new_sig_args, Expr(:parameters, kwargs...))
end
push!(new_sig.args, input_args...)
push!(new_sig_args, input_args...)
new_sig = Expr(:call, new_sig_args...)
if !isnothing(whereparams)
new_sig = Expr(:where, new_sig, whereparams...)
end
Expand Down Expand Up @@ -226,7 +229,7 @@ function forward_methods_expr(Type, field_expr, args...; _sourceinfo=nothing)
method_exprs = args
end

output = Expr(:block)
output_args = Any[]
for arg in method_exprs
_args = @switch arg begin
@case Expr(:block, args...)
Expand All @@ -235,10 +238,10 @@ function forward_methods_expr(Type, field_expr, args...; _sourceinfo=nothing)
[arg]
end
for arg in _args
push!(output.args, forward_method_signature(Type, field_funcs, map_func, arg; _sourceinfo))
push!(output_args, forward_method_signature(Type, field_funcs, map_func, arg; _sourceinfo))
end
end
return output
return Expr(:block, output_args...)
end

"""
Expand Down
8 changes: 4 additions & 4 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ replace_placeholder(x, replace_values) = (x, false)

function replace_placeholder(x::Expr, replace_values::Vector{<:Pair{Symbol,<:Any}})
replaced = false
new_expr = Expr(x.head)
new_expr_args = Any[]
for arg in x.args
new_arg, arg_replaced = replace_placeholder(arg, replace_values)
push!(new_expr.args, new_arg)
push!(new_expr_args, new_arg)
replaced |= arg_replaced
end
return new_expr, replaced
return Expr(x.head, new_expr_args...), replaced
end

identity_map_expr(obj_expr, forwarded_expr) = forwarded_expr
Expand Down Expand Up @@ -113,7 +113,7 @@ function omit_kwarg!(kwargs::Dict{Symbol,Any})
end
end

function get_kwarg(::Type{T}, kwargs, key::Symbol, default) where {T}
function get_kwarg(T::Type, kwargs, key::Symbol, default)
value = get(kwargs, key, default)
value isa T || error("$key (= $value) must be a $T, got typeof($key) = $(typeof(value))")
return value
Expand Down
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ForwardMethods
using TestItemRunner

if VERSION ≥ v"1.9"
using ForwardMethods
using Aqua
Aqua.test_all(ForwardMethods)
end
Expand Down
2 changes: 1 addition & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using TestItemRunner, TestItems
@testsnippet SetupTest begin
using ForwardMethods.MLStyle

using JET, Test, TestingUtilities
using Test, TestingUtilities

macro test_throws_compat(ExceptionType, message, expr)
output = Expr(:block, __source__, :($Test.@test_throws $ExceptionType $expr))
Expand Down
4 changes: 1 addition & 3 deletions test/test_forward_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@
@Test B([0])[1] == 0
c = C([1])
@Test length(c) == 1
@static if VERSION >= v"1.9"
@test_opt length(c)
end

end
end
end
Loading