diff --git a/execution.go b/execution.go index 423e6bd..5ac8254 100644 --- a/execution.go +++ b/execution.go @@ -57,7 +57,7 @@ func HandleExecuteRequest(receipt MsgReceipt) { content["execution_count"] = ExecCounter // Do the compilation/execution magic. - val, err, stderr := REPLSession.Eval(code) + val, stderr, err := REPLSession.Eval(code) if err == nil { content["status"] = "ok" diff --git a/gophernotes_test.go b/gophernotes_test.go index ab79c54..36d302f 100644 --- a/gophernotes_test.go +++ b/gophernotes_test.go @@ -31,7 +31,7 @@ func TestRun_import(t *testing.T) { } for _, code := range codes { - _, err, _ := s.Eval(code) + _, _, err := s.Eval(code) noError(t, err) } } @@ -52,7 +52,7 @@ func TestRun_QuickFix_evaluated_but_not_used(t *testing.T) { } for _, code := range codes { - _, err, _ := s.Eval(code) + _, _, err := s.Eval(code) noError(t, err) } } @@ -70,7 +70,7 @@ func TestRun_QuickFix_used_as_value(t *testing.T) { } for _, code := range codes { - _, err, _ := s.Eval(code) + _, _, err := s.Eval(code) noError(t, err) } } @@ -90,7 +90,7 @@ func TestRun_Copy(t *testing.T) { } for _, code := range codes { - _, err, _ := s.Eval(code) + _, _, err := s.Eval(code) noError(t, err) } } @@ -107,7 +107,7 @@ func TestRun_Const(t *testing.T) { } for _, code := range codes { - _, err, _ := s.Eval(code) + _, _, err := s.Eval(code) noError(t, err) } } diff --git a/internal/repl/repl.go b/internal/repl/repl.go index 5f4081b..33e9be8 100644 --- a/internal/repl/repl.go +++ b/internal/repl/repl.go @@ -2,7 +2,6 @@ package replpkg import ( "bytes" - "flag" "fmt" "io/ioutil" "os" @@ -23,36 +22,12 @@ import ( // Importing this package installs Import as go/types.DefaultImport. "golang.org/x/tools/imports" - "github.com/mitchellh/go-homedir" "github.com/motemen/go-quickfix" ) -const version = "0.2.5" -const printerName = "__gore_p" +const printerName = "__gophernotes" -var ( - flagAutoImport = flag.Bool("autoimport", false, "formats and adjusts imports automatically") - flagExtFiles = flag.String("context", "", - "import packages, functions, variables and constants from external golang source files") - flagPkg = flag.String("pkg", "", "specify a package where the session will be run inside") -) - -func homeDir() (home string, err error) { - home = os.Getenv("GORE_HOME") - if home != "" { - return - } - - home, err = homedir.Dir() - if err != nil { - return - } - - home = filepath.Join(home, ".gore") - return -} - -// Session encodes info about the current REPL session +// Session encodes info about the current REPL session. type Session struct { FilePath string File *ast.File @@ -82,7 +57,7 @@ func main() { ` // printerPkgs is a list of packages that provides -// pretty printing function. Preceding first. +// pretty printing function. var printerPkgs = []struct { path string code string @@ -94,7 +69,6 @@ var printerPkgs = []struct { // NewSession initiates a new REPL func NewSession() (*Session, error) { - var err error s := &Session{ Fset: token.NewFileSet(), @@ -103,6 +77,7 @@ func NewSession() (*Session, error) { }, } + var err error s.FilePath, err = tempFile() if err != nil { return nil, err @@ -117,12 +92,11 @@ func NewSession() (*Session, error) { } debugf("could not import %q: %s", pp.path, err) } - if initialSource == "" { - return nil, fmt.Errorf(`Could not load pretty printing package (even "fmt"; something is wrong)`) + return nil, fmt.Errorf("Could not load pretty printing package") } - s.File, err = parser.ParseFile(s.Fset, "gore_session.go", initialSource, parser.Mode(0)) + s.File, err = parser.ParseFile(s.Fset, "gophernotes_session.go", initialSource, parser.Mode(0)) if err != nil { return nil, err } @@ -136,22 +110,22 @@ func (s *Session) mainFunc() *ast.FuncDecl { return s.File.Scope.Lookup("main").Decl.(*ast.FuncDecl) } -// Run calls "go run" with appropriate files appended -func (s *Session) Run() ([]byte, error, bytes.Buffer) { +// Run calls "go run" with appropriate files appended. +func (s *Session) Run() ([]byte, bytes.Buffer, error) { f, err := os.Create(s.FilePath) if err != nil { - return []byte{}, err, bytes.Buffer{} + return nil, bytes.Buffer{}, err } err = printer.Fprint(f, s.Fset, s.File) if err != nil { - return []byte{}, err, bytes.Buffer{} + return nil, bytes.Buffer{}, err } return goRun(append(s.ExtraFilePaths, s.FilePath)) } -// tempFile prepares the temporary session file for the REPL +// tempFile prepares the temporary session file for the REPL. func tempFile() (string, error) { dir, err := ioutil.TempDir("", "") if err != nil { @@ -163,10 +137,10 @@ func tempFile() (string, error) { return "", err } - return filepath.Join(dir, "gore_session.go"), nil + return filepath.Join(dir, "gophernotes_session.go"), nil } -func goRun(files []string) ([]byte, error, bytes.Buffer) { +func goRun(files []string) ([]byte, bytes.Buffer, error) { var stderr bytes.Buffer @@ -174,11 +148,9 @@ func goRun(files []string) ([]byte, error, bytes.Buffer) { debugf("go %s", strings.Join(args, " ")) cmd := exec.Command("go", args...) cmd.Stdin = os.Stdin - //cmd.Stdout = os.Stdout - //cmd.Stderr = os.Stderr cmd.Stderr = &stderr out, err := cmd.Output() - return out, err, stderr + return out, stderr, err } func (s *Session) evalExpr(in string) (ast.Expr, error) { @@ -246,10 +218,10 @@ func (s *Session) appendStatements(stmts ...ast.Stmt) { s.mainBody.List = append(s.mainBody.List, stmts...) } -// Error is an exported type error +// Error is an exported error. type Error string -// ErrContinue and ErrQuit are specific exported errors +// ErrContinue and ErrQuit are specific exported error types. const ( ErrContinue Error = "" ErrQuit Error = "" @@ -285,7 +257,7 @@ func (s *Session) reset() error { return err } - file, err := parser.ParseFile(s.Fset, "gore_session.go", source, parser.Mode(0)) + file, err := parser.ParseFile(s.Fset, "gophernotes_session.go", source, parser.Mode(0)) if err != nil { return err } @@ -297,34 +269,56 @@ func (s *Session) reset() error { } // Eval handles the evaluation of code parsed from received messages -func (s *Session) Eval(in string) (string, error, bytes.Buffer) { +func (s *Session) Eval(in string) (string, bytes.Buffer, error) { debugf("eval >>> %q", in) s.clearQuickFix() s.storeMainBody() - for _, command := range commands { - arg := strings.TrimPrefix(in, ":"+command.name) - if arg == in { + // Split the lines of the input to check for special commands. + inLines := strings.Split(in, "\n") + var nonImportLines []string + for _, line := range inLines { + + // Extract non-special lines. + if !strings.HasPrefix(line, "import") && !strings.HasPrefix(line, ":") { + nonImportLines = append(nonImportLines, line) continue } - if arg == "" || strings.HasPrefix(arg, " ") { - arg = strings.TrimSpace(arg) + // Process special commands. + for _, command := range commands { - result, err := command.action(s, arg) - if err != nil { - if err == ErrQuit { - return "", err, bytes.Buffer{} - } - errorf("%s: %s", command.name, err) + // Extract any argument provided with the special command. + arg := strings.TrimPrefix(line, ":"+command.name) + if command.name == "import" { + arg = strings.TrimPrefix(arg, "import") + } + if arg == line { + continue } - s.doQuickFix() - return result, nil, bytes.Buffer{} + // Apply the action associated with the special command. + if arg == "" || strings.HasPrefix(arg, " ") { + arg = strings.TrimSpace(arg) + _, err := command.action(s, arg) + if err != nil { + if err == ErrQuit { + return "", bytes.Buffer{}, err + } + errorf("%s: %s", command.name, err) + } + } } } + // Join the non-special lines back together for evaluation. + in = strings.Join(nonImportLines, "\n") + if len(in) == 0 { + s.doQuickFix() + return "", bytes.Buffer{}, nil + } + if _, err := s.evalExpr(in); err != nil { debugf("expr :: err = %s", err) @@ -363,18 +357,15 @@ func (s *Session) Eval(in string) (string, error, bytes.Buffer) { if err = s.importFile(functproxy); err != nil { errorf("%s", err) if _, ok := err.(scanner.ErrorList); ok { - return "", ErrContinue, bytes.Buffer{} + return "", bytes.Buffer{}, ErrContinue } } } } - if *flagAutoImport { - s.fixImports() - } s.doQuickFix() - output, err, strerr := s.Run() + output, strerr, err := s.Run() if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { // if failed with status 2, remove the last statement @@ -388,7 +379,7 @@ func (s *Session) Eval(in string) (string, error, bytes.Buffer) { errorf("%s", err) } - return string(output), err, strerr + return string(output), strerr, err } // storeMainBody stores current state of code so that it can be restored