From 8db2eafa6273a9e340be30075946441ad8e14ffd Mon Sep 17 00:00:00 2001 From: Georgiy Komarov Date: Wed, 21 Sep 2022 15:29:30 +0300 Subject: [PATCH] feat(formatter): Preserve comments in `Formatter` This is not finished work. We need to support: + top-level comments + comments between components And test it properly. Closes #1086 --- src/base/FrontEndParser.ml | 1 + src/base/ScillaLexer.mll | 26 +- src/formatter/ExtendedSyntax.ml | 559 ++++++++++++++++++++++++++++++++ src/formatter/Formatter.ml | 149 +++++---- src/runners/scilla_fmt.ml | 7 +- 5 files changed, 675 insertions(+), 67 deletions(-) create mode 100644 src/formatter/ExtendedSyntax.ml diff --git a/src/base/FrontEndParser.ml b/src/base/FrontEndParser.ml index ed9f69084..d721eb29f 100644 --- a/src/base/FrontEndParser.ml +++ b/src/base/FrontEndParser.ml @@ -87,4 +87,5 @@ module ScillaFrontEndParser (Literal : ScillaLiteral) = struct let parse_expr_from_stdin () = parse_stdin Parser.Incremental.exp_term let parse_lmodule filename = parse_file Parser.Incremental.lmodule filename let parse_cmodule filename = parse_file Parser.Incremental.cmodule filename + let get_comments () = Lexer.get_comments () end diff --git a/src/base/ScillaLexer.mll b/src/base/ScillaLexer.mll index 7165186a8..3a819cde0 100644 --- a/src/base/ScillaLexer.mll +++ b/src/base/ScillaLexer.mll @@ -29,6 +29,11 @@ module MkLexer (S : ParserUtil.Syn) = struct exception Error of string + let comments = ref [] + let add_comment start_p s = + let loc = ErrorUtils.toLoc start_p in + comments := (loc, s) :: !comments + let get_comments () = List.rev !comments } let digit = ['0'-'9'] @@ -55,7 +60,7 @@ rule read = (* Whitespaces *) | newline { new_line lexbuf; read lexbuf } - | "(*" { comment [lexbuf.lex_curr_p] lexbuf } + | "(*" { comment (Buffer.create 50) [lexbuf.lex_start_p] lexbuf } | white { read lexbuf } (* Numbers and hashes *) @@ -148,16 +153,19 @@ and read_string buf = | eof { raise (Error ("String is not terminated")) } (* Nested comments, 
keeping a list of where comments open *)
-and comment braces =
+and comment buf braces =
   parse
-  | "(*"    { comment (lexbuf.lex_curr_p::braces) lexbuf}
+  | "(*"    { (* keep the delimiter of a nested comment in the recorded text *) Buffer.add_string buf (Lexing.lexeme lexbuf); comment buf (lexbuf.lex_curr_p::braces) lexbuf }
   | "*)"    { match braces with
-                _::[] -> read lexbuf
-              | _ -> comment (List.tl_exn braces) lexbuf }
-  | newline { new_line lexbuf; comment braces lexbuf}
-  | _       { comment braces lexbuf}
-  | eof     { lexbuf.lex_curr_p <- List.hd_exn braces; raise (Error ("Comment unfinished"))}
+                p::[] -> add_comment p (Buffer.contents buf);
+                         read lexbuf
+              | _ -> (* closing a nested comment: keep its delimiter too *) Buffer.add_string buf (Lexing.lexeme lexbuf); comment buf (List.tl_exn braces) lexbuf }
+  | newline { new_line lexbuf; comment buf braces lexbuf }
+  | _       { Buffer.add_string buf (Lexing.lexeme lexbuf);
+              comment buf braces lexbuf }
+  | eof     { lexbuf.lex_curr_p <- List.hd_exn braces;
+              raise (Error ("Comment unfinished")) }
 
 {
 end
-}
\ No newline at end of file
+}
diff --git a/src/formatter/ExtendedSyntax.ml b/src/formatter/ExtendedSyntax.ml
new file mode 100644
index 000000000..432ad482f
--- /dev/null
+++ b/src/formatter/ExtendedSyntax.ml
@@ -0,0 +1,559 @@
+(*
+  This file is part of scilla.
+
+  Copyright (c) 2018 - present Zilliqa Research Pvt. Ltd.
+
+  scilla is free software: you can redistribute it and/or modify it under the
+  terms of the GNU General Public License as published by the Free Software
+  Foundation, either version 3 of the License, or (at your option) any later
+  version.
+
+  scilla is distributed in the hope that it will be useful, but WITHOUT ANY
+  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+  A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License along with
+  scilla. If not, see .
+*)
+
+open Core
+open Sexplib.Std
+open Scilla_base
+open ErrorUtils
+open Literal
+open GasCharge
+
+(** Annotated Scilla syntax extended with comment nodes.
*) +module ExtendedScillaSyntax + (SR : Syntax.Rep) + (ER : Syntax.Rep) + (Lit : ScillaLiteral) = +struct + module SLiteral = Lit + module SType = SLiteral.LType + module SIdentifier = SType.TIdentifier + module SGasCharge = ScillaGasCharge (SIdentifier.Name) + + type comment_pos = ComLeft | ComAbove | ComRight [@@deriving sexp] + type comment = loc * string * comment_pos [@@deriving sexp] + type annot_comment = comment list [@@deriving sexp] + + type 'a id_ann = 'a SIdentifier.t * annot_comment [@@deriving sexp] + (** Annotated identifier that may be commented *) + + (*******************************************************) + (* Expressions *) + (*******************************************************) + + type payload = MLit of SLiteral.t | MVar of ER.rep id_ann [@@deriving sexp] + + type pattern = + | Wildcard + | Binder of ER.rep id_ann + | Constructor of SR.rep id_ann * pattern list + [@@deriving sexp] + + type expr_annot = expr * ER.rep * annot_comment + + and expr = + | Literal of SLiteral.t + | Var of ER.rep id_ann + | Let of ER.rep id_ann * SType.t option * expr_annot * expr_annot + | Message of (string * payload) list + | Fun of ER.rep id_ann * SType.t * expr_annot + | App of ER.rep id_ann * ER.rep id_ann list + | Constr of SR.rep id_ann * SType.t list * ER.rep id_ann list + | MatchExpr of ER.rep id_ann * (pattern * expr_annot) list + | Builtin of ER.rep Syntax.builtin_annot * SType.t list * ER.rep id_ann list + | TFun of ER.rep id_ann * expr_annot + | TApp of ER.rep id_ann * SType.t list + | Fixpoint of ER.rep id_ann * SType.t * expr_annot + | GasExpr of SGasCharge.gas_charge * expr_annot + [@@deriving sexp] + + (*******************************************************) + (* Statements *) + (*******************************************************) + + type bcinfo_query = + | CurBlockNum + | ChainID + | Timestamp of ER.rep id_ann + | ReplicateContr of (ER.rep id_ann * ER.rep id_ann) + [@@deriving sexp] + + type stmt_annot = stmt * SR.rep * annot_comment + 
+ and stmt = + | Load of ER.rep id_ann * ER.rep id_ann + | RemoteLoad of ER.rep id_ann * ER.rep id_ann * ER.rep id_ann + | Store of ER.rep id_ann * ER.rep id_ann + | Bind of ER.rep id_ann * expr_annot + | MapUpdate of ER.rep id_ann * ER.rep id_ann list * ER.rep id_ann option + | MapGet of ER.rep id_ann * ER.rep id_ann * ER.rep id_ann list * bool + | RemoteMapGet of + ER.rep id_ann + * ER.rep id_ann + * ER.rep id_ann + * ER.rep id_ann list + * bool + | MatchStmt of ER.rep id_ann * (pattern * stmt_annot list) list + | ReadFromBC of ER.rep id_ann * bcinfo_query + | TypeCast of ER.rep id_ann * ER.rep id_ann * SType.t + | AcceptPayment (** [AcceptPayment] is an [accept] statement. *) + | Iterate of ER.rep id_ann * SR.rep id_ann + | SendMsgs of ER.rep id_ann + | CreateEvnt of ER.rep id_ann + | CallProc of SR.rep id_ann * ER.rep id_ann list + | Throw of ER.rep id_ann option + | GasStmt of SGasCharge.gas_charge + [@@deriving sexp] + + (*******************************************************) + (* Contracts *) + (*******************************************************) + + type component = { + comp_type : Syntax.component_type; + comp_name : SR.rep id_ann; + comp_params : (ER.rep id_ann * SType.t) list; + comp_body : stmt_annot list; + } + [@@deriving sexp] + + type ctr_def = { cname : ER.rep id_ann; c_arg_types : SType.t list } + [@@deriving sexp] + + type lib_entry = + | LibVar of ER.rep id_ann * SType.t option * expr_annot + | LibTyp of ER.rep id_ann * ctr_def list + [@@deriving sexp] + + type library = { lname : SR.rep id_ann; lentries : lib_entry list } + [@@deriving sexp] + + type contract = { + cname : SR.rep id_ann; + cparams : (ER.rep id_ann * SType.t) list; + cconstraint : expr_annot; + cfields : (ER.rep id_ann * SType.t * expr_annot) list; + ccomps : component list; + } + [@@deriving sexp] + + type cmodule = { + smver : int; + (* file_comment : string option; *) + (* lib_comment : string option; *) + libs : library option; + elibs : (SR.rep id_ann * SR.rep id_ann 
option) list; + (* contr_comment : string option; *) + contr : contract; + } + [@@deriving sexp] + (** The structure of the extended [cmodule] is: + scilla_version 0 + + (* File comment *) + + import X + + (* Library comment *) + library ExampleLib + + (* Contract comment *) + contract ExampleContr() + *) + + (* Library module *) + type lmodule = { + smver : int; + (* Scilla major version of the library. *) + (* List of imports / external libs with an optional namespace. *) + elibs : (SR.rep id_ann * SR.rep id_ann option) list; + libs : library; (* lib functions defined in the module *) + } + [@@deriving sexp] + + (* A tree of libraries linked to their dependents *) + type libtree = { + libn : library; + (* The library this node represents *) + deps : libtree list; (* List of dependent libraries *) + } +end + +module ExtendedScillaSyntaxTransformer + (SR : Syntax.Rep) + (ER : Syntax.Rep) + (Lit : Literal.ScillaLiteral) = +struct + module Syn = Syntax.ScillaSyntax (SR) (ER) (Lit) + module ExtSyn = ExtendedScillaSyntax (SR) (ER) (Lit) + module SType = Lit.LType + module SIdentifier = SType.TIdentifier + + type t = { mutable comments : (loc * string) list } + + let mk comments = + (* Comments are already sorted by line number and column, because the lexer + works by this way. *) + { comments } + + (** Creates comment annotations that must be placed between [loc_start] and + [loc_end] and removes them from the [tr.comments] list. 
*)
+  let place_comments tr loc_start loc_end =
+    let rec aux acc = function
+      (* Placing comments left: same line as [loc_start], at a smaller column *)
+      | (loc, s) :: xs
+        when (* (* com *) start end *)
+             phys_equal loc.lnum loc_start.lnum && loc.cnum < loc_start.cnum ->
+          tr.comments <- xs;
+          aux ((loc, s, ExtSyn.ComLeft) :: acc) xs
+      (* Placing comments above *)
+      | (loc, s) :: xs
+        when (* (* com *)   (* com *)
+                start end   start
+                            end *)
+             loc.lnum < loc_start.lnum
+             || (* (* com *) start end   (* com *) start
+                                                   end *)
+             (loc.lnum <= loc_start.lnum && loc.cnum < loc_start.cnum) ->
+          tr.comments <- xs;
+          aux ((loc, s, ExtSyn.ComAbove) :: acc) xs
+      (* Placing comments right: the comment sits on the line of [loc_start], after it, and either [loc_end] is on a later line or the comment starts before [loc_end] on that same line *)
+      | (loc, s) :: xs
+        when (* start (* com *)
+                end *)
+             phys_equal loc.lnum loc_start.lnum
+             && loc.lnum < loc_end.lnum && loc.cnum > loc_start.cnum
+             || (* start (* com *) end *)
+             phys_equal loc.lnum loc_start.lnum
+             && phys_equal loc.lnum loc_end.lnum
+             && loc.cnum > loc_start.cnum && loc.cnum < loc_end.cnum ->
+          tr.comments <- xs;
+          aux ((loc, s, ExtSyn.ComRight) :: acc) xs
+      | _ -> acc (* stop at the first comment that does not belong here; [acc] holds the latest comment first *)
+    in
+    if ErrorUtils.compare_loc loc_start loc_end > 0 then []
+    else aux [] tr.comments
+
+  let extend_id ?(rep_end = None) tr id get_loc =
+    let id_loc = get_loc (SIdentifier.get_rep id) in
+    let end_loc =
+      Option.value_map rep_end ~default:id_loc ~f:(fun rep -> get_loc rep)
+    in
+    let comments = place_comments tr id_loc end_loc in
+    (id, comments)
+
+  let extend_er_id ?(rep_end = None) tr id = extend_id tr id ~rep_end ER.get_loc
+  let extend_sr_id ?(rep_end = None) tr id = extend_id tr id ~rep_end SR.get_loc
+
+  let extend_payload tr = function
+    | Syn.MLit l -> ExtSyn.MLit l
+    | Syn.MVar v -> ExtSyn.MVar (extend_er_id tr v)
+
+  let rec extend_pattern tr = function
+    | Syn.Wildcard -> ExtSyn.Wildcard
+    | Syn.Binder id -> ExtSyn.Binder (extend_er_id tr id)
+    | Syn.Constructor (id, args) ->
+        let args' = List.map args ~f:(fun arg -> extend_pattern tr arg) in
+        ExtSyn.Constructor (extend_sr_id tr id, args')
+
+  let rec extend_expr tr (e, ann) =
+    let 
comment ?(rep_end = ann) () = + place_comments tr (ER.get_loc ann) (ER.get_loc rep_end) + in + match e with + | Syn.Literal l -> (ExtSyn.Literal l, ann, comment ()) + | Syn.Var id -> + let c = comment () in + let id' = extend_er_id tr id in + (ExtSyn.Var id', ann, c) + | Syn.Let (id, ty, (lhs, lhs_rep), rhs) -> + let c = comment () ~rep_end:(SIdentifier.get_rep id) in + let id' = extend_er_id tr id ~rep_end:(Some lhs_rep) in + let lhs' = extend_expr tr (lhs, lhs_rep) in + let rhs' = extend_expr tr rhs in + (ExtSyn.Let (id', ty, lhs', rhs'), ann, c) + | Syn.Message msgs -> + let c = comment () in + let msgs' = + List.map msgs ~f:(fun (s, pld) -> (s, extend_payload tr pld)) + in + (ExtSyn.Message msgs', ann, c) + | Syn.Fun (id, ty, (body, body_rep)) -> + let c = comment ~rep_end:(SIdentifier.get_rep id) () in + let id' = extend_er_id tr id ~rep_end:(Some body_rep) in + let body' = extend_expr tr (body, body_rep) in + (ExtSyn.Fun (id', ty, body'), ann, c) + | Syn.App (id, args) -> + let c = comment () in + let id' = extend_er_id tr id in + let args' = List.map args ~f:(fun arg -> extend_er_id tr arg) in + (ExtSyn.App (id', args'), ann, c) + | Syn.Constr (id, tys, args) -> + let c = comment () in + let id' = extend_sr_id tr id in + let args' = List.map args ~f:(fun arg -> extend_er_id tr arg) in + (ExtSyn.Constr (id', tys, args'), ann, c) + | Syn.MatchExpr (id, arms) -> + let c = comment () in + let id' = extend_er_id tr id in + let arms' = + List.map arms ~f:(fun (pat, body) -> + (extend_pattern tr pat, extend_expr tr body)) + in + (ExtSyn.MatchExpr (id', arms'), ann, c) + | Syn.Builtin (builtin, ty, args) -> + let c = comment () in + let args' = List.map args ~f:(fun arg -> extend_er_id tr arg) in + (ExtSyn.Builtin (builtin, ty, args'), ann, c) + | Syn.TFun (id, (body, body_rep)) -> + let c = comment ~rep_end:body_rep () in + let id' = extend_er_id tr id in + let body' = extend_expr tr (body, body_rep) in + (ExtSyn.TFun (id', body'), ann, c) + | Syn.TApp (id, tys) -> 
+ let c = comment () in + let id' = extend_er_id tr id in + (ExtSyn.TApp (id', tys), ann, c) + | Syn.Fixpoint (id, ty, (body, body_rep)) -> + let c = comment ~rep_end:body_rep () in + let id' = extend_er_id tr id in + let body' = extend_expr tr (body, body_rep) in + (ExtSyn.Fixpoint (id', ty, body'), ann, c) + | Syn.GasExpr (gc, (body, body_rep)) -> + let c = comment ~rep_end:body_rep () in + let body' = extend_expr tr (body, body_rep) in + let gc' = + Syn.SGasCharge.sexp_of_gas_charge gc + |> ExtSyn.SGasCharge.gas_charge_of_sexp + in + (ExtSyn.GasExpr (gc', body'), ann, c) + + let extend_bcinfo_query tr = function + | Syn.CurBlockNum -> ExtSyn.CurBlockNum + | Syn.ChainID -> ExtSyn.ChainID + | Syn.Timestamp id -> ExtSyn.Timestamp (extend_er_id tr id) + | Syn.ReplicateContr (addr, param) -> + let addr' = extend_er_id tr addr in + let param' = extend_er_id tr param in + ExtSyn.ReplicateContr (addr', param') + + let rec extend_stmt tr (s, ann) = + let comment loc_end = place_comments tr (SR.get_loc ann) loc_end in + let loc_end_er id = SIdentifier.get_rep id |> ER.get_loc in + let loc_end_sr id = SIdentifier.get_rep id |> SR.get_loc in + match s with + | Syn.Load (lhs, rhs) -> + let c = comment (loc_end_er lhs) in + let lhs' = extend_er_id tr lhs in + let rhs' = extend_er_id tr rhs in + (ExtSyn.Load (lhs', rhs'), ann, c) + | Syn.RemoteLoad (lhs, addr, rhs) -> + let c = comment (loc_end_er lhs) in + let lhs' = + extend_er_id tr lhs ~rep_end:(Some (SIdentifier.get_rep addr)) + in + let addr' = + extend_er_id tr addr ~rep_end:(Some (SIdentifier.get_rep rhs)) + in + let rhs' = extend_er_id tr rhs in + (ExtSyn.RemoteLoad (lhs', addr', rhs'), ann, c) + | Syn.Store (lhs, rhs) -> + let c = comment (loc_end_er lhs) in + let lhs' = extend_er_id tr lhs in + let rhs' = extend_er_id tr rhs in + (ExtSyn.Store (lhs', rhs'), ann, c) + | Syn.Bind (id, body) -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + let ea' = extend_expr tr body in + (ExtSyn.Bind (id', 
ea'), ann, c) + | Syn.MapUpdate (m, keys, v) -> + let c = comment (loc_end_er m) in + let m' = extend_er_id tr m in + let keys' = List.map keys ~f:(fun k -> extend_er_id tr k) in + let v' = + Option.value_map v ~default:None ~f:(fun v -> + Some (extend_er_id tr v)) + in + (ExtSyn.MapUpdate (m', keys', v'), ann, c) + | Syn.MapGet (v, m, keys, retrieve) -> + let c = comment (loc_end_er v) in + let v' = extend_er_id tr v in + let m' = extend_er_id tr m in + let keys' = List.map keys ~f:(fun k -> extend_er_id tr k) in + (ExtSyn.MapGet (v', m', keys', retrieve), ann, c) + | Syn.RemoteMapGet (v, addr, m, keys, retrieve) -> + let c = comment (loc_end_er v) in + let v' = extend_er_id tr v in + let addr' = extend_er_id tr addr in + let m' = extend_er_id tr m in + let keys' = List.map keys ~f:(fun k -> extend_er_id tr k) in + (ExtSyn.RemoteMapGet (v', addr', m', keys', retrieve), ann, c) + | Syn.MatchStmt (id, arms) -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + let arms' = + List.map arms ~f:(fun (pat, stmts) -> + let pat' = extend_pattern tr pat in + let stmts' = + List.map stmts ~f:(fun stmt -> extend_stmt tr stmt) + in + (pat', stmts')) + in + (ExtSyn.MatchStmt (id', arms'), ann, c) + | Syn.ReadFromBC (id, q) -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + let q' = extend_bcinfo_query tr q in + (ExtSyn.ReadFromBC (id', q'), ann, c) + | Syn.TypeCast (id, addr, ty) -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + let addr' = extend_er_id tr addr in + (ExtSyn.TypeCast (id', addr', ty), ann, c) + | Syn.AcceptPayment -> + let c = comment (SR.get_loc ann) in + (ExtSyn.AcceptPayment, ann, c) + | Syn.Iterate (l, f) -> + let c = comment (loc_end_er l) in + let l' = extend_er_id tr l in + let f' = extend_sr_id tr f in + (ExtSyn.Iterate (l', f'), ann, c) + | Syn.SendMsgs id -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + (ExtSyn.SendMsgs id', ann, c) + | Syn.CreateEvnt id 
-> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + (ExtSyn.CreateEvnt id', ann, c) + | Syn.CallProc (id, args) -> + let c = comment (loc_end_sr id) in + let id' = extend_sr_id tr id in + let args' = List.map args ~f:(fun arg -> extend_er_id tr arg) in + (ExtSyn.CallProc (id', args'), ann, c) + | Syn.Throw id_opt -> ( + match id_opt with + | Some id -> + let c = comment (loc_end_er id) in + let id' = extend_er_id tr id in + (ExtSyn.Throw (Some id'), ann, c) + | None -> + let c = comment (SR.get_loc ann) in + (ExtSyn.Throw None, ann, c)) + | Syn.GasStmt gc -> + let c = comment (SR.get_loc ann) in + let gc' = + Syn.SGasCharge.sexp_of_gas_charge gc + |> ExtSyn.SGasCharge.gas_charge_of_sexp + in + (ExtSyn.GasStmt gc', ann, c) + + let extend_ctr_def tr (ctr : Syn.ctr_def) = + let cname' = extend_er_id tr ctr.cname in + { ExtSyn.cname = cname'; c_arg_types = ctr.c_arg_types } + + let extend_lentry tr = function + | Syn.LibVar (id, ty_opt, ea) -> + let id' = extend_er_id tr id in + let ea' = extend_expr tr ea in + ExtSyn.LibVar (id', ty_opt, ea') + | Syn.LibTyp (id, ctrs) -> + let id' = extend_er_id tr id in + let ctrs' = List.map ctrs ~f:(fun ctr -> extend_ctr_def tr ctr) in + ExtSyn.LibTyp (id', ctrs') + + let extend_lib tr (lib : Syn.library) = + let lname' = extend_sr_id tr lib.lname in + let lentries' = + List.map lib.lentries ~f:(fun lentry -> extend_lentry tr lentry) + in + { ExtSyn.lname = lname'; lentries = lentries' } + + let extend_elib tr elib = + let import, import_as = elib in + let import' = extend_sr_id tr import in + let import_as' = + Option.value_map import_as ~default:None ~f:(fun id -> + Some (extend_sr_id tr id)) + in + (import', import_as') + + let extend_component tr comp = + let comp_type = comp.Syn.comp_type in + let comp_name = extend_sr_id tr comp.comp_name in + let comp_params = + List.map comp.comp_params ~f:(fun (id, ty) -> (extend_er_id tr id, ty)) + in + let comp_body = + List.map comp.comp_body ~f:(fun stmt -> 
extend_stmt tr stmt) + in + { ExtSyn.comp_type; comp_name; comp_params; comp_body } + + let extend_contract tr (contr : Syn.contract) : ExtSyn.contract = + let cname = extend_sr_id tr contr.cname in + let cparams = + List.map contr.cparams ~f:(fun (id, ty) -> (extend_er_id tr id, ty)) + in + let cconstraint = extend_expr tr contr.cconstraint in + let cfields = + List.map contr.cfields ~f:(fun (id, ty, init) -> + (extend_er_id tr id, ty, extend_expr tr init)) + in + let ccomps = List.map contr.ccomps ~f:(fun c -> extend_component tr c) in + { cname; cparams; cconstraint; cfields; ccomps } + + (* (** Extracts the file-level comment of the [cmod] based on its locations. *) *) + (* let extract_file_comment tr (cmod : Syn.cmodule) : string option = *) + (* match List.hd tr.comments with *) + (* | Some (comment_loc, comment) -> ( *) + (* (* The file-level comment must be above the first import if there are *) + (* any, or above the library definition or the contract definition *) + (* otherwise. 
*) *) + (* match cmod.elibs with *) + (* | [] when List.is_empty cmod.elibs && Option.is_some cmod.libs -> *) + (* (* scilla_version *) + (* (* File comment *) *) + (* (* Library comment *) *) + (* library Example *) *) + (* None *) + (* | [] when List.is_empty cmod.elibs && Option.is_none cmod.libs -> *) + (* (* scilla_version *) + (* (* File comment *) *) + (* (* Contract comment *) *) + (* library Example *) *) + (* None *) + (* | [] when not @@ List.is_empty cmod.elibs -> *) + (* (* scilla_version *) + (* (* File comment *) *) + (* import X *) *) + (* None *) + (* | _ -> *) + (* let first_import_loc = *) + (* List.hd_exn cmod.elibs |> fun (id, _) -> *) + (* SR.get_loc (SIdentifier.get_rep id) *) + (* in *) + (* if first_import_loc.lnum > comment_loc.lnum then ( *) + (* tr.comments <- List.tl_exn tr.comments; *) + (* Some comment) *) + (* else None) *) + (* | None -> None (* no comments in this file *) *) + + let extend_cmodule tr (cmod : Syn.cmodule) : ExtSyn.cmodule = + let smver = cmod.smver in + let elibs = List.map cmod.elibs ~f:(fun l -> extend_elib tr l) in + let libs = + Option.value_map cmod.libs ~default:None ~f:(fun l -> + Some (extend_lib tr l)) + in + let contr = extend_contract tr cmod.contr in + { smver; libs; elibs; contr } +end + +module LocalLiteralTransformer = + ExtendedScillaSyntaxTransformer (ParserUtil.ParserRep) (ParserUtil.ParserRep) + (Literal.LocalLiteral) diff --git a/src/formatter/Formatter.ml b/src/formatter/Formatter.ml index 3df4c047b..3c9683235 100644 --- a/src/formatter/Formatter.ml +++ b/src/formatter/Formatter.ml @@ -27,8 +27,8 @@ open PPrint module Format (SR : Syntax.Rep) (ER : Syntax.Rep) (Lit : Literal.ScillaLiteral) = struct - (* instantiated syntax *) - module Ast = Syntax.ScillaSyntax (SR) (ER) (Lit) + (* instantiated syntax extended with comments *) + module Ast = ExtendedSyntax.ExtendedScillaSyntax (SR) (ER) (Lit) module type DOC = sig val of_type : Ast.SType.t -> PPrint.document @@ -92,12 +92,44 @@ struct (* Add 
parentheses only if the condition if true *)
   let parens_if cond doc = if cond then parens doc else doc
 
+  (** Add formatted [comments] around [doc]: [ComAbove] comments go before it, each followed by a line break, [ComLeft] comments directly before it and [ComRight] comments directly after it. [comments] is expected latest-first; every group is emitted in source order. *)
+  let wrap_comments comments doc =
+    let comment = enclose !^"(*" !^"*)" in
+    let spaced s =
+      let has_prefix prefix = String.is_prefix s ~prefix in
+      let has_suffix suffix = String.is_suffix s ~suffix in
+      let s = if has_prefix " " || has_prefix "*" then s else " " ^ s in
+      let s = if has_suffix " " || has_suffix "*" then s else s ^ " " in
+      s
+    in
+    let left, above, right =
+      List.fold_left comments
+        ~init:([],[],[])
+        ~f:(fun (acc_l, acc_a, acc_r) -> function
+          | (_, s, Ast.ComLeft) ->
+            [comment !^(spaced s); space] @ acc_l, acc_a, acc_r
+          | (_, s, Ast.ComAbove) ->
+            acc_l, (comment !^(spaced s))::acc_a, acc_r
+          | (_, s, Ast.ComRight) ->
+            acc_l, acc_a, [space; comment !^(spaced s)] @ acc_r)
+      |> fun (l, a, r) ->
+        let a' = if List.is_empty a then empty
+          else (concat_map (fun c -> c ^^^ hardline) a)
+        in
+        let l' = concat l in
+        let r' = concat r in
+        l', a', r'
+    in
+    concat [above; left; doc; right]
+
   let of_builtin b = !^(Syntax.pp_builtin b)
 
   let of_id id = !^(Ast.SIdentifier.as_error_string id)
 
-  let of_ids ids =
-    separate_map space of_id ids
+  let of_ann_id (id, comments) = of_id id |> wrap_comments comments
+
+  let of_ann_ids ids =
+    separate_map space of_ann_id ids
 
   let rec of_type_with_prec p typ =
     let open Ast.SType in
@@ -150,7 +182,7 @@ struct
   let of_types typs ~sep =
     group @@ separate_map sep (fun ty -> of_type_with_prec 1 ty) typs
 
-  let of_typed_id id typ = of_id id ^^^ colon ^//^ group (of_type typ)
+  let of_typed_ann_id id typ = of_ann_id id ^^^ colon ^//^ group (of_type typ)
 
   let rec of_literal lit =
     let rec walk p = function
@@ -202,14 +234,14 @@ struct
 
   let of_payload = function
     | Ast.MLit lit -> of_literal lit
-    | Ast.MVar id -> of_id id
+    | Ast.MVar id -> of_ann_id id
 
   let of_pattern pat =
     let rec of_pattern_aux ~top_parens = function
       | Ast.Wildcard -> !^"_"
-      | Ast.Binder id -> of_id id
+      | 
Ast.Binder id -> of_ann_id id | Ast.Constructor (constr_id, pats) -> - let constr_id = of_id constr_id in + let constr_id = of_ann_id constr_id in if List.is_empty pats then constr_id else @@ -218,10 +250,10 @@ struct in of_pattern_aux ~top_parens:false pat - let rec of_expr (expr, _ann) = - match expr with + let rec of_expr (expr, _ann, comments) = + (match expr with | Ast.Literal lit -> of_literal lit - | Ast.Var id -> of_id id + | Ast.Var id -> of_ann_id id | Ast.Fun (id, typ, body) -> (* TODO: nested functions should not be indented: fun (a : String) => @@ -233,20 +265,20 @@ struct let body = of_expr body in (* fun ($id : $typ) => $body *) - fun_kwd ^^^ parens (of_typed_id id typ) ^^^ darrow ^^ indent (hardline ^^ body) + fun_kwd ^^^ parens (of_typed_ann_id id typ) ^^^ darrow ^^ indent (hardline ^^ body) | Ast.App (fid, args) -> - let fid = of_id fid - and args = of_ids args in + let fid = of_ann_id fid + and args = of_ann_ids args in fid ^//^ args | Ast.Builtin ((builtin, _ann), _types, typed_ids) -> let builtin = of_builtin builtin - and args = of_ids typed_ids in + and args = of_ann_ids typed_ids in builtin_kwd ^^^ builtin ^//^ args | Ast.Let (id, otyp, lhs, body) -> let id = match otyp with - | None -> of_id id - | Some typ -> of_typed_id id typ + | None -> of_ann_id id + | Some typ -> of_typed_ann_id id typ and lhs = of_expr lhs and body = of_expr body in (* @@ -271,27 +303,27 @@ struct *) (group (group (let_kwd ^^^ id ^^^ equals ^//^ lhs) ^/^ in_kwd)) ^/^ body | Ast.TFun (ty_var, body) -> - let ty_var = of_id ty_var + let ty_var = of_ann_id ty_var and body = of_expr body in (* tfun $ty_var => $body *) (* (^/^) -- means concat with _breakable_ space *) tfun_kwd ^^^ ty_var ^^^ darrow ^//^ body | Ast.TApp (id, typs) -> - let tfid = of_id id + let tfid = of_ann_id id (* TODO: remove unnecessary parens around primitive types: e.g. "Nat" does not need parens but "forall 'X. 
'X" needs them in type applications *) and typs = separate_map space (fun typ -> parens @@ of_type typ) typs in at ^^ tfid ^//^ typs | Ast.MatchExpr (ident, branches) -> - match_kwd ^^^ of_id ident ^^^ with_kwd ^/^ + match_kwd ^^^ of_ann_id ident ^^^ with_kwd ^/^ separate_map hardline (fun (pat, e) -> group (pipe ^^^ of_pattern pat ^^^ darrow ^//^ group (of_expr e))) branches ^^ hardline ^^ end_kwd | Ast.Constr (id, typs, args) -> - let id = of_id id + let id = of_ann_id id (* TODO: remove unnecessary parens around primitive types *) - and args_doc = of_ids args in + and args_doc = of_ann_ids args in if Base.List.is_empty typs then if Base.List.is_empty args then id else id ^//^ args_doc else @@ -310,44 +342,46 @@ struct rbrace | Fixpoint _ -> failwith "Fixpoints cannot appear in user contracts" | GasExpr _ -> failwith "Gas annotations cannot appear in user contracts's expressions" + ) |> wrap_comments comments let of_map_access map keys = - let map = of_id map - and keys = concat_map (fun k -> brackets @@ of_id k) keys in + let map = of_ann_id map + and keys = concat_map (fun k -> brackets @@ of_ann_id k) keys in map ^^ keys - let rec of_stmt (stmt, _ann) = match stmt with + let rec of_stmt (stmt, _ann, comments) = + (match stmt with | Ast.Load (id, field) -> - of_id id ^^^ rev_arrow ^//^ of_id field + of_ann_id id ^^^ rev_arrow ^//^ of_ann_id field | Ast.RemoteLoad (id, addr, field) -> - of_id id ^^^ blockchain_arrow ^//^ of_id addr ^^ dot ^^ of_id field + of_ann_id id ^^^ blockchain_arrow ^//^ of_ann_id addr ^^ dot ^^ of_ann_id field | Ast.Store (field, id) -> - of_id field ^^^ assign ^//^ of_id id + of_ann_id field ^^^ assign ^//^ of_ann_id id | Ast.Bind (id, expr) -> - of_id id ^^^ equals ^//^ of_expr expr + of_ann_id id ^^^ equals ^//^ of_expr expr | Ast.MapUpdate (map, keys, mode) -> (* m[k1][k2][..] := v OR delete m[k1][k2][...] 
*) (match mode with - | Some value -> of_map_access map keys ^^^ assign ^//^ of_id value + | Some value -> of_map_access map keys ^^^ assign ^//^ of_ann_id value | None -> delete_kwd ^^^ of_map_access map keys) | Ast.MapGet (id, map, keys, mode) -> (* v <- m[k1][k2][...] OR b <- exists m[k1][k2][...] *) (* If the bool is set, then we interpret this as value retrieve, otherwise as an "exists" query. *) if mode then - of_id id ^^^ rev_arrow ^//^ of_map_access map keys + of_ann_id id ^^^ rev_arrow ^//^ of_map_access map keys else - of_id id ^^^ rev_arrow ^//^ exists_kwd ^^^ of_map_access map keys + of_ann_id id ^^^ rev_arrow ^//^ exists_kwd ^^^ of_map_access map keys | Ast.RemoteMapGet (id, addr, map, keys, mode) -> (* v <-& adr.m[k1][k2][...] OR b <-& exists adr.m[k1][k2][...] *) (* If the bool is set, then we interpret this as value retrieve, otherwise as an "exists" query. *) if mode then - of_id id ^^^ blockchain_arrow ^//^ of_id addr ^^ dot ^^ of_map_access map keys + of_ann_id id ^^^ blockchain_arrow ^//^ of_ann_id addr ^^ dot ^^ of_map_access map keys else - of_id id ^^^ blockchain_arrow ^//^ exists_kwd ^^^ of_id addr ^^ dot ^^ of_map_access map keys + of_ann_id id ^^^ blockchain_arrow ^//^ exists_kwd ^^^ of_ann_id addr ^^ dot ^^ of_map_access map keys | Ast.MatchStmt (id, branches) -> - match_kwd ^^^ of_id id ^^^ with_kwd ^/^ + match_kwd ^^^ of_ann_id id ^^^ with_kwd ^/^ separate_map hardline (fun (pat, stmts) -> group (pipe ^^^ of_pattern pat ^^^ darrow ^//^ group (of_stmts stmts))) branches @@ -357,30 +391,31 @@ struct match query with | CurBlockNum -> blocknumber_kwd | ChainID -> chainid_kwd - | Timestamp ts -> timestamp_kwd ^^ parens (of_id ts) + | Timestamp ts -> timestamp_kwd ^^ parens (of_ann_id ts) | ReplicateContr (addr, init_params) -> - replicate_contract_kwd ^^ parens (of_id addr ^^ comma ^^^ of_id init_params) + replicate_contract_kwd ^^ parens (of_ann_id addr ^^ comma ^^^ of_ann_id init_params) in - of_id id ^^^ blockchain_arrow ^//^ query + 
of_ann_id id ^^^ blockchain_arrow ^//^ query | Ast.TypeCast (id, addr, typ) -> - of_id id ^^^ blockchain_arrow ^//^ of_id addr ^^^ as_kwd ^^^ of_type typ + of_ann_id id ^^^ blockchain_arrow ^//^ of_ann_id addr ^^^ as_kwd ^^^ of_type typ | Ast.AcceptPayment -> accept_kwd | Ast.Iterate (arg_list, proc) -> (* forall l p *) - forall_kwd ^//^ of_id arg_list ^//^ of_id proc + forall_kwd ^//^ of_ann_id arg_list ^//^ of_ann_id proc | Ast.SendMsgs msgs -> - send_kwd ^//^ of_id msgs + send_kwd ^//^ of_ann_id msgs | Ast.CreateEvnt events -> - event_kwd ^//^ of_id events + event_kwd ^//^ of_ann_id events | Ast.CallProc (proc, args) -> - if List.is_empty args then of_id proc - else of_id proc ^//^ of_ids args + if List.is_empty args then of_ann_id proc + else of_ann_id proc ^//^ of_ann_ids args | Ast.Throw oexc -> (match oexc with | None -> throw_kwd - | Some exc -> throw_kwd ^//^ of_id exc) + | Some exc -> throw_kwd ^//^ of_ann_id exc) | Ast.GasStmt _ -> failwith "Gas annotations cannot appear in user contracts's statements" + ) |> wrap_comments comments and of_stmts stmts = separate_map (semi ^^ hardline) (fun s -> of_stmt s) stmts @@ -391,13 +426,13 @@ struct lparen (separate_map (comma ^^ sep) - (fun (p, typ) -> of_typed_id p typ) + (fun (p, typ) -> of_typed_ann_id p typ) typed_params) rparen let of_component Ast.{comp_type; comp_name; comp_params; comp_body} = let comp_type = !^(Syntax.component_type_to_string comp_type) - and comp_name = of_id comp_name + and comp_name = of_ann_id comp_name and comp_params = of_parameters comp_params ~sep:(break 1) and comp_body = of_stmts comp_body in group (comp_type ^^^ comp_name ^//^ comp_params) ^^ @@ -405,7 +440,7 @@ struct end_kwd let of_ctr_def Ast.{cname; c_arg_types} = - let constructor_name = of_id cname + let constructor_name = of_ann_id cname and constructor_args_types = (* TODO: break sequences of long types (e.g. ByStr20 with contract ................... 
end Uint256 is unreadable) *) of_types ~sep:(break 1) c_arg_types @@ -419,12 +454,12 @@ struct | Ast.LibVar (definition, otyp, expr) -> let definition = match otyp with - | None -> of_id definition - | Some typ -> of_typed_id definition typ + | None -> of_ann_id definition + | Some typ -> of_typed_ann_id definition typ and expr = of_expr expr in let_kwd ^^^ definition ^^^ equals ^//^ expr | Ast.LibTyp (typ_name, constr_defs) -> - let typ_name = of_id typ_name + let typ_name = of_ann_id typ_name and constr_defs = separate_map hardline (fun cd -> pipe ^^^ of_ctr_def cd) constr_defs in @@ -432,7 +467,7 @@ struct constr_defs let of_library Ast.{lname; lentries} = - library_kwd ^^^ of_id lname ^^ + library_kwd ^^^ of_ann_id lname ^^ if List.is_empty lentries then hardline else let lentries = @@ -444,12 +479,12 @@ struct twice hardline ^^ lentries ^^ hardline let of_contract Ast.{cname; cparams; cconstraint; cfields; ccomps} = - let cname = of_id cname + let cname = of_ann_id cname and cparams = of_parameters cparams ~sep:hardline and cconstraint = let true_ctr = Lit.LType.TIdentifier.Name.parse_simple_name "True" in match cconstraint with - | (Ast.Literal (Lit.ADTValue (c, [], [])), _annot) when [%equal: _] c true_ctr -> + | (Ast.Literal (Lit.ADTValue (c, [], [])), _annot, _comment) when [%equal: _] c true_ctr -> (* trivially True contract constraint does not get rendered *) empty | _ -> @@ -462,7 +497,7 @@ struct separate_map (twice hardline) (fun (field, typ, init) -> - field_kwd ^^^ of_typed_id field typ ^^^ equals ^//^ of_expr init) + field_kwd ^^^ of_typed_ann_id field typ ^^^ equals ^//^ of_expr init) cfields ^^ twice hardline and ccomps = @@ -477,8 +512,8 @@ struct let imports = let import_lib (lib, onamespace) = match onamespace with - | None -> of_id lib - | Some namespace -> of_id lib ^^^ as_kwd ^^^ of_id namespace + | None -> of_ann_id lib + | Some namespace -> of_ann_id lib ^^^ as_kwd ^^^ of_ann_id namespace in let imported_libs = separate_map (hardline) 
(fun imp -> import_lib imp) elibs
diff --git a/src/runners/scilla_fmt.ml b/src/runners/scilla_fmt.ml
index eee371cb2..abc9c1b98 100644
--- a/src/runners/scilla_fmt.ml
+++ b/src/runners/scilla_fmt.ml
@@ -92,6 +92,9 @@ let scilla_sexp_fmt deannotated human_readable file =
 let scilla_source_code_fmt file =
   let open FilePath in
   let open StdlibTracker in
+  let tr () = (* call only after parsing: the lexer records comments as a side effect of lexing *)
+    ExtendedSyntax.LocalLiteralTransformer.mk (FEParser.get_comments ())
+  in
   if check_extension file file_extn_library then
     (* library modules *)
     (* file
@@ -100,12 +103,14 @@ let scilla_source_code_fmt file =
        |> *)
     failwith "Formatting of Scilla library modules is not implemented yet"
   else if check_extension file file_extn_contract then
-  (* contract modules *)
+    (* contract modules *)
     file |> FEParser.parse_cmodule |> unpack_ast_exn
+    |> (fun ast -> ExtendedSyntax.LocalLiteralTransformer.extend_cmodule (tr ()) ast)
     |> Formatter.LocalLiteralSyntax.contract_to_string
   else if check_extension file file_extn_expression then
     (* expressions *)
     file |> FEParser.parse_expr_from_file |> unpack_ast_exn
+    |> (fun ast -> ExtendedSyntax.LocalLiteralTransformer.extend_expr (tr ()) ast)
     |> Formatter.LocalLiteralSyntax.expr_to_string
   else fatal_error (mk_error0 ~kind:"Unknown file extension" ?inst:None)