summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--io_uring/napi.c19
-rw-r--r--io_uring/napi.h6
2 files changed, 9 insertions, 16 deletions
diff --git a/io_uring/napi.c b/io_uring/napi.c
index 921de9de8d75..5e2299e7ff8e 100644
--- a/io_uring/napi.c
+++ b/io_uring/napi.c
@@ -38,22 +38,14 @@ static inline ktime_t net_to_ktime(unsigned long t)
return ns_to_ktime(t << 10);
}
-void __io_napi_add(struct io_ring_ctx *ctx, struct socket *sock)
+int __io_napi_add_id(struct io_ring_ctx *ctx, unsigned int napi_id)
{
struct hlist_head *hash_list;
- unsigned int napi_id;
- struct sock *sk;
struct io_napi_entry *e;
- sk = sock->sk;
- if (!sk)
- return;
-
- napi_id = READ_ONCE(sk->sk_napi_id);
-
/* Non-NAPI IDs can be rejected. */
if (napi_id < MIN_NAPI_ID)
- return;
+ return -EINVAL;
hash_list = &ctx->napi_ht[hash_min(napi_id, HASH_BITS(ctx->napi_ht))];
@@ -62,13 +54,13 @@ void __io_napi_add(struct io_ring_ctx *ctx, struct socket *sock)
if (e) {
WRITE_ONCE(e->timeout, jiffies + NAPI_TIMEOUT);
rcu_read_unlock();
- return;
+ return -EEXIST;
}
rcu_read_unlock();
e = kmalloc(sizeof(*e), GFP_NOWAIT);
if (!e)
- return;
+ return -ENOMEM;
e->napi_id = napi_id;
e->timeout = jiffies + NAPI_TIMEOUT;
@@ -77,12 +69,13 @@ void __io_napi_add(struct io_ring_ctx *ctx, struct socket *sock)
if (unlikely(io_napi_hash_find(hash_list, napi_id))) {
spin_unlock(&ctx->napi_lock);
kfree(e);
- return;
+ return -EEXIST;
}
hlist_add_tail_rcu(&e->node, hash_list);
list_add_tail_rcu(&e->list, &ctx->napi_list);
spin_unlock(&ctx->napi_lock);
+ return 0;
}
static void __io_napi_remove_stale(struct io_ring_ctx *ctx)
diff --git a/io_uring/napi.h b/io_uring/napi.h
index fd275ef0456d..4ae622f37b30 100644
--- a/io_uring/napi.h
+++ b/io_uring/napi.h
@@ -15,7 +15,7 @@ void io_napi_free(struct io_ring_ctx *ctx);
int io_register_napi(struct io_ring_ctx *ctx, void __user *arg);
int io_unregister_napi(struct io_ring_ctx *ctx, void __user *arg);
-void __io_napi_add(struct io_ring_ctx *ctx, struct socket *sock);
+int __io_napi_add_id(struct io_ring_ctx *ctx, unsigned int napi_id);
void __io_napi_busy_loop(struct io_ring_ctx *ctx, struct io_wait_queue *iowq);
int io_napi_sqpoll_busy_poll(struct io_ring_ctx *ctx);
@@ -48,8 +48,8 @@ static inline void io_napi_add(struct io_kiocb *req)
return;
sock = sock_from_file(req->file);
- if (sock)
- __io_napi_add(ctx, sock);
+ if (sock && sock->sk)
+ __io_napi_add_id(ctx, READ_ONCE(sock->sk->sk_napi_id));
}
#else