/*
* common code for virtio vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*
* This work is licensed under the terms of the GNU GPL, version 2.
*/
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>
/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
static const struct virtio_transport *virtio_transport_get_ops(void)
{
const struct vsock_transport *t = vsock_core_get_transport();
return container_of(t, struct virtio_transport, transport);
}
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
size_t len,
u32 src_cid,
u32 src_port,
u32 dst_cid,
u32 dst_port)
{
struct virtio_vsock_pkt *pkt;
int err;
pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
if (!pkt)
return NULL;
pkt->hdr.type = cpu_to_le16(info->type);
pkt->hdr.op = cpu_to_le16(info->op);
pkt->hdr.src_cid = cpu_to_le64(src_cid);
pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
pkt->hdr.src_port = cpu_to_le32(src_port);
pkt->hdr.dst_port = cpu_to_le32(dst_port);
pkt->hdr.flags = cpu_to_le32(info->flags);
pkt->len = len;
pkt->hdr.len = cpu_to_le32(len);
pkt->reply = info->reply;
if (info->msg && len > 0) {
pkt->buf = kmalloc(len, GFP_KERNEL);
if (!pkt->buf)
goto out_pkt;
err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err)
goto out;
}
trace_virtio_transport_alloc_pkt(src_cid, src_port,
dst_cid, dst_port,
len,
info->type,
info->op,
info->flags);
return pkt;
out:
kfree(pkt->buf);
out_pkt:
kfree(pkt);
return NULL;
}
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt_info *info)
{
u32 src_cid, src_port, dst_cid, dst_port;
struct virtio_vsock_sock *vvs;
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
src_cid = vm_sockets_get_local_cid();
src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) {
dst_cid = vsk->remote_addr.svm_cid;
dst_port = vsk->remote_addr.svm_port;
} else {
dst_cid = info->remote_cid;
dst_port = info->remote_port;
}
vvs = vsk->trans;
/* we can send less than pkt_len bytes */
if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
/* virtio_transport_get_credit might return less than pkt_len credit */
pkt_len = virtio_transport_get_credit(vvs, pkt_len);
/* Do not send zero length OP_RW pkt */
if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
return pkt_len;
pkt = virtio_transport_alloc_pkt(info, pkt_len,
src_cid, src_port,
dst_cid, dst_port);
if (!pkt) {
virtio_transport_put_credit(vvs, pkt_len);
return -ENOMEM;
}
virtio_transport_inc_tx_pkt(vvs, pkt);
return virtio_transport_get_ops()->send_pkt(pkt);
}
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
struct virtio_vsock_pkt *pkt)
{
vvs->rx_bytes += pkt->len;
}
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
struct virtio_vsock_pkt *pkt)
{
vvs->rx_bytes -= pkt->len;
vvs->fwd_cnt += pkt->len;
}
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
spin_lock_bh(&vvs->tx_lock);
pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt