diff --git a/us.c b/us.c index 4048275..68e9605 100644 --- a/us.c +++ b/us.c @@ -51,6 +51,7 @@ #define MAX_HASH 1024 #define CONF_LINE_MAX 1024 #define GROUPS_MAX 256 +#define STR_MAX 1024 #define FLAG_PERSIST 0x1 #define FLAG_NOPASS 0x2 #define FLAG_NOLOG 0x4 @@ -69,6 +70,14 @@ struct config { int env_n; }; +struct user_info { + union { + struct passwd pw; + struct group gr; + } d; + char str[STR_MAX]; +}; + static void *emalloc(size_t); static char *estrdup(const char *); void *erealloc(void *, size_t); @@ -76,8 +85,8 @@ static void usage(void); static void die(const char *, ...); static int perm_set(struct passwd *, struct group *); static int authenticate(uid_t, char); -static struct passwd* user_to_passwd(const char *); -static struct group* group_to_grp(const char *); +static struct passwd* user_to_passwd(const char *, struct user_info *); +static struct group* group_to_grp(const char *, struct user_info *); static int get_config(struct config **, int *); extern char **environ; @@ -88,6 +97,7 @@ int main(int argc, char *argv[]) char *t_usr = "root", *t_grp = NULL; struct passwd *t_pw; struct group *t_gr; + struct user_info t_gr_info = {0}, t_pw_info = {0}; int opt, err; int shellflag = 0, envflag = 0, askpass = 0; while ((opt = getopt(argc, argv, "Au:g:C:se")) != -1) { @@ -120,7 +130,9 @@ int main(int argc, char *argv[]) /* Get user info */ char *shell; uid_t ruid = getuid(); - struct passwd *my_pw = getpwuid(ruid); + struct passwd *my_pw = NULL; + struct user_info my_info = {0}; + getpwuid_r(ruid, &my_info.d.pw, my_info.str, STR_MAX, &my_pw); if (!my_pw) { fprintf(stderr, "getpwid: %s\n", strerror(errno)); return errno; @@ -132,10 +144,10 @@ int main(int argc, char *argv[]) die("getgroups:"); /* Get target user and group info */ - t_pw = user_to_passwd(t_usr); + t_pw = user_to_passwd(t_usr, &t_pw_info); if (!t_pw) die("user_to_passwd:"); - t_gr = group_to_grp(t_grp); + t_gr = group_to_grp(t_grp, &t_gr_info); /* Don't have to wait for children */ struct sigaction sa = {0}; @@ -153,45 +165,44 @@ int main(int argc, char *argv[]) struct env_elem *env_extra = NULL; struct config *conf = NULL; int conf_num, conf_flags = 0, env_extra_n = 0; - if (get_config(&conf, &conf_num) == -1) - die("get_config: invalid arguments"); + if (get_config(&conf, &conf_num) <= 0) + die("get_config: invalid config"); int here = 0; for (int i = 0; i < conf_num; i++) { struct passwd *who_pw, *as_pw; struct group *who_gr, *as_gr; + struct user_info who_info = {0}, as_info = {0}; int who_usr = conf[i].who[0] == ':' ? 0 : 1; int as_usr = conf[i].as[0] == ':' ? 0 : 1; if (who_usr) { - who_pw = user_to_passwd(conf[i].who); - if (my_pw->pw_uid != who_pw->pw_uid) { - free(who_pw); + who_pw = user_to_passwd(conf[i].who, &who_info); + if (!who_pw) + die("%s not a valid user", conf[i].who); + if (my_pw->pw_uid != who_pw->pw_uid) continue; - } } else { - who_gr = group_to_grp(conf[i].who); + who_gr = group_to_grp(conf[i].who, &who_info); + if (!who_gr) + die("%s not a valid group", conf[i].who); gid_t w_gid = who_gr->gr_gid; int x = 0; for (; x < n_groups && w_gid != my_groups[x]; x++); - if (w_gid != my_groups[x]) { - free(who_gr); + if (w_gid != my_groups[x]) continue; - } } if (as_usr) { - as_pw = user_to_passwd(conf[i].as); + as_pw = user_to_passwd(conf[i].as, &as_info); if (!as_pw) die("%s not a valid user", conf[i].as); - if (t_pw->pw_uid != as_pw->pw_uid) { - free(as_pw); + if (t_pw->pw_uid != as_pw->pw_uid) continue; - } } else if (t_gr) { - as_gr = group_to_grp(conf[i].as); - if (t_gr->gr_gid != as_gr->gr_gid) { - free(as_gr); + as_gr = group_to_grp(conf[i].as, &as_info); + if (!as_gr) + die("%s not a valid group", conf[i].as); + if (t_gr->gr_gid != as_gr->gr_gid) continue; - } } else { continue; } @@ -210,6 +221,16 @@ int main(int argc, char *argv[]) } } + /* We don't need conf anymore */ + for (int i = 0; i < conf_num; i++) { + free(conf[i].who); + free(conf[i].as); + if (conf[i].env && conf[i].env_n) + free(conf[i].env); + } + free(conf); + + /* No configuration was fount, can't proceed */ if (!here) die("no rule found for user %s", my_name); @@ -228,7 +249,7 @@ int main(int argc, char *argv[]) /* Set argc and argv */ int c_argc = argc - optind; - char **c_argv; + char **c_argv = NULL; if (c_argc) { c_argv = emalloc(sizeof(char *) * (c_argc + 1)); for (int i = 0; optind < argc; optind++, i++) @@ -487,66 +508,56 @@ static int authenticate(uid_t uid, char ask) return 0; } -static struct passwd* user_to_passwd(const char *user) +static struct passwd* user_to_passwd(const char *user, struct user_info *info) { if (!user) { errno = EINVAL; return NULL; } - struct passwd *pw, *pr; + struct passwd *pw; long uid_l; errno = 0; if (user[0] != '#') { - pw = getpwnam(user); + getpwnam_r(user, &(info->d.pw), info->str, STR_MAX, &pw); } else { uid_l = strtol(&user[1], NULL, 10); if (uid_l < 0 || errno) { errno = errno ? errno : EINVAL; return NULL; } - pw = getpwuid((uid_t)uid_l); + getpwuid_r((uid_t)uid_l, &(info->d.pw), info->str, STR_MAX, &pw); } - if (!pw) { - if (!errno) - errno = EINVAL; - return NULL; - } - pr = emalloc(sizeof(struct passwd)); - memcpy(pr, pw, sizeof(struct passwd)); - return pr; + if (!pw && !errno) + errno = EINVAL; + return pw; } -static struct group* group_to_grp(const char *group) +static struct group* group_to_grp(const char *group, struct user_info *info) { if (!group) { errno = EINVAL; return NULL; } - struct group *gr, *rr; + struct group *gr; long gid_l; errno = 0; if (group[0] != '#') { - gr = getgrnam(group); + getgrnam_r(group, &(info->d.gr), info->str, STR_MAX, &gr); } else { gid_l = strtol(&group[1], NULL, 10); if (gid_l < 0 || errno) { errno = errno ? errno : EINVAL; return NULL; } - gr = getgrgid((gid_t)gid_l); + getgrgid_r((gid_t)gid_l, &(info->d.gr), info->str, STR_MAX, &gr); } - if (!gr) { - if (!errno) - errno = EINVAL; - return NULL; - } - rr = emalloc(sizeof(struct group)); - memcpy(rr, gr, sizeof(struct group)); - return rr; + if (!gr && !errno) + errno = EINVAL; + return gr; } void die(const char *fmt, ...) @@ -690,5 +701,5 @@ static int get_config(struct config **conf, int *num) *num += 1; } fclose(fp); - return 0; + return *num; }