diff --git a/main.go b/main.go index f526003b..49bbc831 100644 --- a/main.go +++ b/main.go @@ -7,10 +7,12 @@ package main // interface. import ( + "errors" "flag" "fmt" "log" "os" + "path" "runtime/pprof" "strconv" "syscall" @@ -144,27 +146,33 @@ func main() { func initRuntime() (*eval.Evaler, *api.Client) { var dataDir string var err error - if *dbpath == "" || *sockpath == "" { - // Determine default paths for database and socket. - dataDir, err = storedefs.EnsureDataDir() - if err != nil { - fmt.Fprintln(os.Stderr, "warning: cannot create data dir ~/.elvish") - } else { - if *dbpath == "" { - *dbpath = dataDir + "/db" - } - if *sockpath == "" { - *sockpath = dataDir + "/sock" - } + + // Determine data directory. + dataDir, err = storedefs.EnsureDataDir() + if err != nil { + fmt.Fprintln(os.Stderr, "warning: cannot create data directory ~/.elvish") + } else { + if *dbpath == "" { + *dbpath = dataDir + "/db" } } + // Determine runtime directory. + runDir, err := getSecureRunDir() + if err != nil { + fmt.Fprintln(os.Stderr, "cannot get runtime dir /tmp/elvish-$uid, falling back to data dir ~/.elvish:", err) + runDir = dataDir + } + if *sockpath == "" { + *sockpath = runDir + "/sock" + } + toSpawn := &daemon.Daemon{ Forked: *forked, BinPath: *binpath, DbPath: *dbpath, SockPath: *sockpath, - LogPathPrefix: dataDir + "/daemon.log.", + LogPathPrefix: runDir + "/daemon.log.", } var cl *api.Client if *sockpath != "" && *dbpath != "" { @@ -209,3 +217,34 @@ spawnDaemonEnd: return eval.NewEvaler(cl, toSpawn, dataDir), cl } + +var ( + ErrBadOwner = errors.New("bad owner") + ErrBadPermission = errors.New("bad permission") +) + +// getSecureRunDir stats /tmp/elvish-$uid, creating it if it doesn't yet exist, +// and return the directory name if it has the correct owner and permission. +func getSecureRunDir() (string, error) { + uid := syscall.Getuid() + + runDir := path.Join(os.TempDir(), fmt.Sprintf("elvish-%d", uid)) + err := os.MkdirAll(runDir, 0700) + if err != nil { + return "", fmt.Errorf("mkdir: %v", err) + } + + var stat syscall.Stat_t + err = syscall.Stat(runDir, &stat) + if err != nil { + return "", fmt.Errorf("stat: %v", err) + } + + if int(stat.Uid) != uid { + return "", ErrBadOwner + } + if stat.Mode&077 != 0 { + return "", ErrBadPermission + } + return runDir, err +}