Skip to content

Commit

Permalink
pidfd: Allow processes to have common dead pidfd
Browse files Browse the repository at this point in the history
This patch ensures that the process that creates the tmp process is the
one that kills and waits for it when all pidfds have been opened.
We do this by keeping track of how many dead pidfds each process
still has to open.
When the count for the creator of the tmp process reaches 0, it waits
for all other processes to open pidfds and then kills and waits for the
tmp process.

Fixes: checkpoint-restore#2496

Signed-off-by: Bhavik Sachdev <[email protected]>
  • Loading branch information
bsach64 committed Oct 28, 2024
1 parent 1b33428 commit 9810c9e
Showing 1 changed file with 79 additions and 6 deletions.
85 changes: 79 additions & 6 deletions criu/pidfd.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "common/list.h"
#include "common/lock.h"
#include "imgset.h"
#include "pidfd.h"
Expand All @@ -10,6 +11,7 @@
#include <signal.h>
#include "common/bug.h"
#include "rst-malloc.h"
#include <unistd.h>

#undef LOG_PREFIX
#define LOG_PREFIX "pidfd: "
Expand All @@ -23,11 +25,19 @@ struct pidfd_info {
struct file_desc d;
};

/*
 * Per-process counter hanging off a struct dead_pidfd (see its
 * count_per_process list). One entry exists per owner process.
 */
struct dead_count {
/*
 * Set to 1 by init_dead_count(), incremented in collect_one_pidfd() for
 * each further pidfd of the same owner and decremented in
 * open_one_pidfd() each time that owner opens one.
 */
size_t count;
/*
 * Owner identifier: filled from PidfdEntry.pidfd_owner and matched
 * against vpid(current) when the pidfd is opened.
 */
int virt;
/* Link in dead_pidfd.count_per_process. */
struct list_head list;
};

/*
 * Bookkeeping for one dead process that pidfds refer to, looked up by
 * inode number through lookup_dead_pidfd().
 */
struct dead_pidfd {
/* Lookup key; also selects the bucket (ino % DEAD_PIDFD_HASH_SIZE). */
unsigned int ino;
/*
 * -1 until open_one_pidfd() stores the result of create_tmp_process()
 * here; pidfd_open() is then called on this pid.
 */
int pid;
/*
 * NOTE(review): in this view `count` is touched at the same places as
 * `total_count`; this looks like a diff rendering that shows both the old
 * and the new field - confirm which one the final file keeps.
 */
size_t count;
/* Set to vpid(current) by the process that created the tmp process. */
int creator;
/*
 * Incremented for every collected pidfd on this inode and decremented on
 * every open, regardless of which process owns it.
 */
size_t total_count;
/* Held while the counters and `pid` are updated. */
mutex_t pidfd_lock;
/* List of struct dead_count entries, one per owner process. */
struct list_head count_per_process;
/* Link in the dead_pidfd_hash table. */
struct hlist_node hash;
};

Expand All @@ -49,6 +59,30 @@ int init_dead_pidfd_hash(void)
return 0;
}

/*
 * Allocate a per-process counter for @pidfd_owner via shmalloc().
 *
 * The new entry starts with count == 1 and virt == @pidfd_owner. Its list
 * link is not set up here; callers in this file attach it with list_add().
 *
 * Returns the new entry, or NULL (after logging an error) when the
 * allocation fails.
 */
static struct dead_count *init_dead_count(int pidfd_owner)
{
	struct dead_count *counter = shmalloc(sizeof(*counter));

	if (counter) {
		counter->virt = pidfd_owner;
		counter->count = 1;
		return counter;
	}

	pr_err("Could not allocate shared memory..\n");
	return NULL;
}

/*
 * Find the counter belonging to process @pid on @dead's count_per_process
 * list.
 *
 * Returns the first entry whose virt equals @pid, or NULL when no entry
 * matches.
 */
static struct dead_count *lookup_count_per_process(struct dead_pidfd *dead, int pid)
{
	struct dead_count *entry;
	struct dead_count *match = NULL;

	list_for_each_entry(entry, &dead->count_per_process, list) {
		if (entry->virt != pid)
			continue;

		match = entry;
		break;
	}

	return match;
}

static struct dead_pidfd *lookup_dead_pidfd(unsigned int ino)
{
struct dead_pidfd *dead;
Expand Down Expand Up @@ -109,6 +143,7 @@ static int dump_one_pidfd(int pidfd, u32 id, const struct fd_parms *p)
pidfd_info.pidfe.id = id;
pidfd_info.pidfe.flags = (p->flags & ~O_RDWR);
pidfd_info.pidfe.fown = (FownEntry *)&p->fown;
pidfd_info.pidfe.pidfd_owner = pid_to_virt(p->pid);

fe.type = FD_TYPES__PIDFD;
fe.id = pidfd_info.pidfe.id;
Expand Down Expand Up @@ -200,6 +235,7 @@ static int open_one_pidfd(struct file_desc *d, int *new_fd)
{
struct pidfd_info *info;
struct dead_pidfd *dead = NULL;
struct dead_count *spec_count = NULL;
int pidfd;

info = container_of(d, struct pidfd_info, d);
Expand All @@ -215,15 +251,21 @@ static int open_one_pidfd(struct file_desc *d, int *new_fd)
dead = lookup_dead_pidfd(info->pidfe->ino);
BUG_ON(!dead);

spec_count = lookup_count_per_process(dead, vpid(current));
BUG_ON(!spec_count);

mutex_lock(&dead->pidfd_lock);
BUG_ON(dead->count == 0);
dead->count--;
BUG_ON(dead->total_count == 0);
BUG_ON(spec_count->count == 0);
dead->total_count--;
spec_count->count--;
if (dead->pid == -1) {
dead->pid = create_tmp_process();
if (dead->pid < 0) {
mutex_unlock(&dead->pidfd_lock);
goto err_close;
}
dead->creator = vpid(current);
}

pidfd = pidfd_open(dead->pid, info->pidfe->flags);
Expand All @@ -233,7 +275,13 @@ static int open_one_pidfd(struct file_desc *d, int *new_fd)
goto err_close;
}

if (dead->count == 0) {
if (spec_count->count == 0 && dead->creator == getpid()) {
mutex_unlock(&dead->pidfd_lock);
while (dead->total_count != 0) {
usleep(1);
}

mutex_lock(&dead->pidfd_lock);
if (free_dead_pidfd(dead)) {
pr_err("Failed to delete dead_pidfd struct\n");
mutex_unlock(&dead->pidfd_lock);
Expand Down Expand Up @@ -265,6 +313,7 @@ static int collect_one_pidfd(void *obj, ProtobufCMessage *msg, struct cr_img *i)
{
struct dead_pidfd *dead;
struct pidfd_info *info = obj;
struct dead_count *spec_count = NULL;

info->pidfe = pb_msg(msg, PidfdEntry);
pr_info_pidfd("Collected ", info->pidfe);
Expand All @@ -275,7 +324,19 @@ static int collect_one_pidfd(void *obj, ProtobufCMessage *msg, struct cr_img *i)
dead = lookup_dead_pidfd(info->pidfe->ino);
if (dead) {
mutex_lock(&dead->pidfd_lock);
dead->count++;
spec_count = lookup_count_per_process(dead, info->pidfe->pidfd_owner);
if (!spec_count) {
spec_count = init_dead_count(info->pidfe->pidfd_owner);
if (!spec_count) {
mutex_unlock(&dead->pidfd_lock);
return -1;
}
list_add(&spec_count->list, &dead->count_per_process);
} else {
spec_count->count++;
}
dead->total_count++;
pr_info("dead: ino: %d, count: %zu\n", dead->ino, dead->total_count);
mutex_unlock(&dead->pidfd_lock);
goto out;
}
Expand All @@ -287,15 +348,27 @@ static int collect_one_pidfd(void *obj, ProtobufCMessage *msg, struct cr_img *i)
}

INIT_HLIST_NODE(&dead->hash);
INIT_LIST_HEAD(&dead->count_per_process);
dead->ino = info->pidfe->ino;
dead->count = 1;
dead->total_count = 1;
dead->pid = -1;
mutex_init(&dead->pidfd_lock);

spec_count = init_dead_count(info->pidfe->pidfd_owner);
if (!spec_count)
return -1;

list_add(&spec_count->list, &dead->count_per_process);

mutex_lock(dead_pidfd_hash_lock);
hlist_add_head(&dead->hash, &dead_pidfd_hash[dead->ino % DEAD_PIDFD_HASH_SIZE]);
mutex_unlock(dead_pidfd_hash_lock);
out:
if (spec_count) {
list_for_each_entry(spec_count, &dead->count_per_process, list) {
pr_info("spec_count: virt: %d, count: %zu\n", spec_count->virt, spec_count->count);
}
}
return file_desc_add(&info->d, info->pidfe->id, &pidfd_desc_ops);
}

Expand Down

0 comments on commit 9810c9e

Please sign in to comment.