diff --git a/run/run.go b/run/run.go index 2cae573f..4b0db02f 100644 --- a/run/run.go +++ b/run/run.go @@ -96,7 +96,7 @@ func Main() { if len(args) == 1 { if *cmd { - ev.SourceText("code from -c", args[0]) + evalText(ev, "code from -c", args[0]) } else { script(ev, args[0]) } @@ -119,26 +119,43 @@ func rescue() { } func script(ev *eval.Evaler, fname string) { - if source(ev, fname, false) != nil { + if !source(ev, fname, false) { os.Exit(1) } } -func source(ev *eval.Evaler, fname string, notexistok bool) error { +func source(ev *eval.Evaler, fname string, notexistok bool) bool { src, err := readFileUTF8(fname) if err != nil { if notexistok && os.IsNotExist(err) { - return nil + return true } fmt.Fprintln(os.Stderr, err) - return err + return false } - err = ev.SourceText(fname, src) + return evalText(ev, fname, src) +} + +// evalText is like eval.Evaler.SourceText except that it reports errors. +func evalText(ev *eval.Evaler, name, src string) bool { + n, err := parse.Parse(src) if err != nil { - printError(err, fname, "Error", src) + printError(err, name, "Parse error", src) + return false } - return err + + op, err := ev.Compile(n, name, src) + if err != nil { + printError(err, name, "Compile error", src) + return false + } + err = ev.Eval(name, src, op) + if err != nil { + printError(err, name, "Exception", src) + return false + } + return true } func readFileUTF8(fname string) (string, error) { @@ -203,21 +220,7 @@ func interact(ev *eval.Evaler, st *store.Store) { // No error; reset cooldown. cooldown = time.Second - n, err := parse.Parse(line) - if err != nil { - printError(err, "[interactive]", "Parse error", line) - continue - } - - op, err := ev.Compile(n, "[interactive]", line) - if err != nil { - printError(err, "[interactive]", "Compile error", line) - continue - } - err = ev.Eval("[interactive]", line, op) - if err != nil { - printError(err, "[interactive]", "Exception", line) - } + evalText(ev, "[interactive]", line) } } @@ -284,13 +287,13 @@ func printError(err error, srcname, errtype, src string) { case *util.TracebackError: fmt.Fprintln(os.Stderr, err.Pprint()) default: - printErrorString(err.Error()) + printErrorString(errtype, err.Error()) } } -func printErrorString(s string) { +func printErrorString(errtype, s string) { if sys.IsATTY(2) { s = "\033[1;31m" + s + "\033[m" } - fmt.Fprintln(os.Stderr, s) + fmt.Fprintln(os.Stderr, errtype+": "+s) }