Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion library/init/meta/environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import init.meta.declaration init.meta.exceptional init.data.option.basic
import init.meta.rb_map
import init.meta.rb_map init.meta.pexpr

/-- An __environment__ contains all of the declarations and notation that have been defined so far. -/
meta constant environment : Type
Expand Down Expand Up @@ -143,6 +143,10 @@ But there are no `is_inductive`s which are not `is_ginductive`.
meta constant is_ginductive : environment → name → bool
/-- See the docstring for `projection_info`. -/
meta constant is_projection : environment → name → option projection_info

/-- Get the equations specifying a certain definition -/
meta constant defn_spec : environment → name → option pexpr

/-- Fold over declarations in the environment. -/
meta constant fold {α :Type} : environment → α → (declaration → α → α) → α
/-- `relation_info env n` returns some value if n is marked as a relation in the given environment.
Expand Down
4 changes: 2 additions & 2 deletions library/init/meta/expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ meta constant expr.lt : expr → expr → bool
meta constant expr.lex_lt : expr → expr → bool

/-- `expr.fold e a f`: Traverses each subexpression of `e`. The `nat` passed to the folder `f` is the binder depth. -/
meta constant expr.fold {α : Type} : expr → α → (expr → nat → α → α) → α
meta constant expr.fold {elab : opt_param bool tt} {α : Type} : expr elab → α → (expr elab → nat → α → α) → α
/-- `expr.replace e f`
Traverse over an expr `e` with a function `f` which can decide to replace subexpressions or not.
For each subexpression `s` in the expression tree, `f s n` is called where `n` is how many binders are present above the given subexpression `s`.
If `f s n` returns `none`, the children of `s` will be traversed.
Otherwise if `some s'` is returned, `s'` will replace `s` and this subexpression will not be traversed further.
-/
meta constant expr.replace : expr → (expr → nat → option expr) → expr
meta constant expr.replace {elab : opt_param bool tt} : expr elab → (expr elab → nat → option (expr elab)) → expr elab

/-- `abstract_local e n` replaces each instance of the local constant with unique (not pretty) name `n` in `e` with a de-Bruijn variable. -/
meta constant expr.abstract_local : expr → name → expr
Expand Down
7 changes: 7 additions & 0 deletions library/init/meta/tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,13 @@ meta def is_protected_decl (n : name) : tactic bool :=
do env ← get_env,
return $ env.is_protected n

/-- get the set of equations that specify a given
definition -/
meta def get_defn_spec (n : name) : tactic pexpr :=
do env ← get_env,
env.defn_spec n


/-- `add_defn_equations` adds a definition specified by a list of equations.

The arguments:
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/lean/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ match_expr.cpp local_context_adapter.cpp decl_util.cpp definition_cmds.cpp
brackets.cpp tactic_notation.cpp info_manager.cpp json.cpp module_parser.cpp
interactive.cpp completion.cpp
user_notation.cpp user_command.cpp
widget.cpp)
widget.cpp eqn_api.cpp)
if(EMSCRIPTEN)
add_dependencies(lean_frontend gmp)
endif()
endif()
11 changes: 9 additions & 2 deletions src/frontends/lean/definition_cmds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Author: Leonardo de Moura
#include "library/compiler/rec_fn_macro.h"
#include "library/tactic/eqn_lemmas.h"
#include "frontends/lean/parser.h"
#include "frontends/lean/eqn_api.h"
#include "frontends/lean/tokens.h"
#include "frontends/lean/elaborator.h"
#include "frontends/lean/util.h"
Expand Down Expand Up @@ -475,7 +476,10 @@ static environment mutual_definition_cmd_core(parser & p, decl_cmd_kind kind, cm
return p.env();

bool recover_from_errors = true;
elaborator elab(env, p.get_options(), get_namespace(env) + mlocal_pp_name(fns[0]), metavar_context(), local_context(), recover_from_errors);
name full_name = get_namespace(env) + mlocal_pp_name(fns[0]);
env = store_eqn_spec(env, full_name, val);
p.set_env(env);
elaborator elab(env, p.get_options(), full_name, metavar_context(), local_context(), recover_from_errors);
buffer<expr> new_params;
elaborate_params(elab, params, new_params);
val = replace_locals_preserving_pos_info(val, params, new_params);
Expand Down Expand Up @@ -794,7 +798,10 @@ environment single_definition_cmd_core(parser_info & p, decl_cmd_kind kind, cmd_
return p.env();

bool recover_from_errors = p.m_error_recovery;
elaborator elab(env, p.get_options(), get_namespace(env) + mlocal_pp_name(fn), metavar_context(), local_context(), recover_from_errors);
name full_name = get_namespace(env) + mlocal_pp_name(fn);
env = store_eqn_spec(env, full_name, val);
p.set_env(env);
elaborator elab(env, p.get_options(), full_name, metavar_context(), local_context(), recover_from_errors);
buffer<expr> new_params;
elaborate_params(elab, params, new_params);
elab.freeze_local_instances();
Expand Down
105 changes: 105 additions & 0 deletions src/frontends/lean/eqn_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
Copyright (c) 2020 Simon Hudon. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Simon Hudon
*/

#include <map>
#include "kernel/environment.h"
#include "library/module.h"
#include "library/vm/vm.h"
#include "library/vm/vm_name.h"
#include "library/vm/vm_environment.h"
#include "library/vm/vm_expr.h"
#include "library/vm/vm_option.h"
#include "library/kernel_serializer.h"

namespace lean {

struct eqn_info : public environment_extension {
std::map<name, expr> m_eqn_info;

eqn_info() { }
};

struct eqn_info_reg {
unsigned m_ext_id;
eqn_info_reg() {
m_ext_id = environment::register_extension(std::make_shared<eqn_info>(eqn_info()));
}
};

static eqn_info_reg * g_ext = nullptr;

static eqn_info const & get_extension(environment const & env) {
return static_cast<eqn_info const &>(env.get_extension(g_ext->m_ext_id));
}

static environment update(environment const & env, eqn_info const & ext) {
return env.update(g_ext->m_ext_id, std::make_shared<eqn_info>(ext));
}


struct eqn_info_modification : public modification {
LEAN_MODIFICATION("EQN_INFO")

name m_fn;
expr m_eqns;

eqn_info_modification(name const & fn, expr const & eqns): m_fn(fn), m_eqns(eqns) {}

void perform(environment & env) const override {
auto ext = get_extension(env);
ext.m_eqn_info.insert(mk_pair(m_fn, m_eqns));
env = update(env, ext);
}

void serialize(serializer & s) const override {
s << m_fn << m_eqns;
}

static std::shared_ptr<modification const> deserialize(deserializer & d) {
name fn; expr eqns;
d >> fn >> eqns;
return std::make_shared<eqn_info_modification>(fn, eqns);
}
};

optional<expr> get_eqn_spec(environment const & env, name const & n) {
auto ext = get_extension(env);
std::map<name, expr>::iterator eqn = ext.m_eqn_info.find(n);
if (eqn != ext.m_eqn_info.end()) {
return optional<expr>(eqn->second);
} else {
return optional<expr>();
}
}

vm_obj environment_get_eqn_spec(vm_obj const & env, vm_obj const & n) {
environment env_ = to_env(env);
name n_ = to_name(n);
if (auto r = get_eqn_spec(env_, n_)) {
return mk_vm_some(to_obj(*r));
} else {
return mk_vm_none();
}
}

environment store_eqn_spec(environment const & env, name const & n, expr const & e) {
return module::add_and_perform(env, std::make_shared<eqn_info_modification>(n, e));
}

void initialize_eqn_api() {
g_ext = new eqn_info_reg();
eqn_info_modification::init();
DECLARE_VM_BUILTIN(name({"environment", "defn_spec"}), environment_get_eqn_spec);
}

void finalize_eqn_api() {
eqn_info_modification::finalize();
delete g_ext;
}


}
17 changes: 17 additions & 0 deletions src/frontends/lean/eqn_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
Copyright (c) 2020 Simon Hudon. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Simon Hudon
*/
#pragma once
#include "kernel/environment.h"

namespace lean {

void initialize_eqn_api();
void finalize_eqn_api();
environment store_eqn_spec(environment const & env, name const & n, expr const & e);
optional<expr> get_eqn_spec(environment const & env, name const & n);

}
3 changes: 3 additions & 0 deletions src/frontends/lean/init_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Author: Leonardo de Moura
#include "frontends/lean/user_notation.h"
#include "frontends/lean/user_command.h"
#include "frontends/lean/widget.h"
#include "frontends/lean/eqn_api.h"

namespace lean {
void initialize_frontend_lean_module() {
Expand Down Expand Up @@ -61,8 +62,10 @@ void initialize_frontend_lean_module() {
initialize_user_notation();
initialize_user_command();
initialize_widget();
initialize_eqn_api();
}
void finalize_frontend_lean_module() {
finalize_eqn_api();
finalize_user_command();
finalize_user_notation();
finalize_completion();
Expand Down
4 changes: 2 additions & 2 deletions src/library/vm/vm_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ vm_obj expr_lex_lt(vm_obj const & o1, vm_obj const & o2) {
return mk_vm_bool(is_lt(to_expr(o1), to_expr(o2), false));
}

vm_obj expr_fold(vm_obj const &, vm_obj const & e, vm_obj const & a, vm_obj const & fn) {
vm_obj expr_fold(vm_obj const &, vm_obj const &, vm_obj const & e, vm_obj const & a, vm_obj const & fn) {
vm_obj r = a;
for_each(to_expr(e), [&](expr const & o, unsigned d) {
r = invoke(fn, to_obj(o), mk_vm_nat(d), r);
Expand All @@ -246,7 +246,7 @@ vm_obj expr_fold(vm_obj const &, vm_obj const & e, vm_obj const & a, vm_obj cons
return r;
}

vm_obj expr_replace(vm_obj const & e, vm_obj const & fn) {
vm_obj expr_replace(vm_obj const &, vm_obj const & e, vm_obj const & fn) {
expr r = replace(to_expr(e), [&](expr const & o, unsigned d) {
vm_obj new_o = invoke(fn, to_obj(o), mk_vm_nat(d));
if (is_none(new_o))
Expand Down
36 changes: 36 additions & 0 deletions tests/lean/eqn_api.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

import .eqn_api2

def foo_bar (a : Type) (x : a) : option a -> a
| none := x
| (some y) := y

def all : bool -> bool -> bool -> bool -> bool
| ff _ _ _ := ff
| _ ff _ _ := ff
| _ _ ff _ := ff
| _ _ _ ff := ff
| _ _ _ _ := tt

open tactic

meta def replace_internal_name (e : pexpr) : pexpr :=
expr.replace e $ λ e i,
match e with
| expr.const (name.mk_numeral _ _) ls := some $ expr.const `_ ls
| _ := none
end

run_cmd do
e ← get_env,
trace "foo",
environment.defn_spec e ``foo >>= trace ∘ replace_internal_name,
trace "foo_bar",
environment.defn_spec e ``foo_bar >>= trace ∘ replace_internal_name,
trace "all",
environment.defn_spec e ``all >>= trace ∘ replace_internal_name,
trace "f",
environment.defn_spec e ``f >>= trace ∘ replace_internal_name,
trace "g",
environment.defn_spec e ``g >>= trace ∘ replace_internal_name,
skip
12 changes: 12 additions & 0 deletions tests/lean/eqn_api.lean.expected.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
eqn_api.lean:24:0: error: failed
state:
⊢ true
foo
equations (fun (foo : (option a) -> a), (equation ((«@» foo) option.none) x)) (fun (foo : (option a) -> a) (y : _), (equation ((«@» foo) (option.some y)) y))
foo_bar
equations (fun (foo_bar : (option a) -> a), (equation ((«@» foo_bar) option.none) x)) (fun (foo_bar : (option a) -> a) (y : _), (equation ((«@» foo_bar) (option.some y)) y))
all
equations (fun (all : bool -> bool -> bool -> bool -> bool) (_x : _) (_x_1 : _) (_x_2 : _), (equation ((«@» all) bool.ff _x _x_1 _x_2) bool.ff)) (fun (all : bool -> bool -> bool -> bool -> bool) (_x : _) (_x_1 : _) (_x_2 : _), (equation ((«@» all) _x bool.ff _x_1 _x_2) bool.ff)) (fun (all : bool -> bool -> bool -> bool -> bool) (_x : _) (_x_1 : _) (_x_2 : _), (equation ((«@» all) _x _x_1 bool.ff _x_2) bool.ff)) (fun (all : bool -> bool -> bool -> bool -> bool) (_x : _) (_x_1 : _) (_x_2 : _), (equation ((«@» all) _x _x_1 _x_2 bool.ff) bool.ff)) (fun (all : bool -> bool -> bool -> bool -> bool) (_x : _) (_x_1 : _) (_x_2 : _) (_x_3 : _), (equation ((«@» all) _x _x_1 _x_2 _x_3) bool.tt))
f
equations (fun (f : (list nat) -> nat) (g : (list nat) -> nat), (equation ((«@» f) list.nil) 0)) (fun (f : (list nat) -> nat) (g : (list nat) -> nat) (x : _) (xs : _), (equation ((«@» f) (list.cons x xs)) (has_add.add x (g xs)))) (fun (f : (list nat) -> nat) (g : (list nat) -> nat), (equation ((«@» g) list.nil) 0)) (fun (f : (list nat) -> nat) (g : (list nat) -> nat) (x : _) (xs : _), (equation ((«@» g) (list.cons x xs)) (f xs)))
g
12 changes: 12 additions & 0 deletions tests/lean/eqn_api2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

mutual def f, g
with f : list ℕ -> ℕ
| [] := 0
| (x :: xs) := x + g xs
with g : list ℕ -> ℕ
| [] := 0
| (x :: xs) := f xs

def foo (a : Type) (x : a) : option a -> a
| none := x
| (some y) := y
Empty file.