初学rust,使用协程
Tag rust, 协程, on by view 165

最近想使用rust写个tcp透明代理转发服务,这中间涉及到socket监听以及连接处理逻辑中连接后端服务并处理连接后的相关逻辑。

监听端口

fn main() {
    let skt = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap();
    let address: SocketAddr = "0.0.0.0:3333".parse().unwrap();
    skt.set_ip_transparent(true).unwrap();
    skt.set_reuse_address(true).unwrap();
    skt.set_reuse_port(true).unwrap();
    skt.bind(&address.into()).unwrap();
    skt.listen(128).unwrap();

    let listener: TcpListener = skt.into();
    println!("Server listening on port 3333");

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                println!("peer addr: {}", stream.peer_addr().unwrap());
                thread::spawn(move || {
                    // connection succeeded
                    handle_client(stream)
                });
            }
            Err(e) => {
                println!("Error: {}", e);
                /* connection failed */
            }
        }
    }
    // close the socket server
    drop(listener);
}

处理客户端连接

fn handle_client(mut stream: TcpStream) {
    let x = get_original_destination_addr(&stream).unwrap();
    println!("target addr: {:?}", x);

    let mut client_stream = get_rs_connect_stream_by_ipport(x.to_string().as_str());
    println!("client stream: {:?}", client_stream);

    let mut data = [0 as u8; 50]; // using 50 byte buffer
    while match stream.read(&mut data) {
        Ok(size) => {
            if size == 0 {
                false
            } else {
                // echo everything!
                println!("len: {:?}", size);
                stream.write(&data[0..size]).unwrap();
                true
            }
        }
        Err(_) => {
            println!(
                "An error occurred, terminating connection with {}",
                stream.peer_addr().unwrap()
            );
            stream.shutdown(Shutdown::Both).unwrap();
            false
        }
    } {}
}

可以看到这里每接受一个连接,就会起一个线程,如果在handle_client里再连接后端的话,就需要针对client_stream再起一个线程在线程中执行死循环,否则死循环会卡主stream这个死循环。这样一来,一个链接就得起2个线程了,那么连接一多,线程就更多了,一个进程能创建的线程数是有限的,因为线程对资源占用相对比较大,这样连接数一多,系统资源就不够用,性能就会很差。

对比到golang中的协程,我们能否在rust中使用协程呢?答案是肯定的。下面简单介绍一下rust中如何使用协程。rust中使用“协程”,我们用到tokio这个包。

对于单个异步,我们可以用async和await就可以了

#[tokio::main]
async fn main() {
    let skt = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap();
    let address: SocketAddr = "0.0.0.0:3333".parse().unwrap();
    skt.set_ip_transparent(true).unwrap();
    skt.set_reuse_address(true).unwrap();
    skt.set_reuse_port(true).unwrap();
    skt.bind(&address.into()).unwrap();
    skt.listen(128).unwrap();

    let listener: TcpListener = skt.into();
    println!("Server listening on port 3333");

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                println!("peer addr: {}", stream.peer_addr().unwrap());
                handle_client(stream).await  // 其中 handle_client 改为了 async 异步函数
            }
            Err(e) => {
                println!("Error: {}", e);
                /* connection failed */
            }
        }
    }
    // close the socket server
    drop(listener);
}

其中 handle_client 是 async 异步函数,但是如果我想在handle_client中再起一个异步函数,直接使用async是不行的。我们可以这样处理,使用tokio::spawn就可以了,这里tokio::spawn有点类似golang中的go关键词,调用tokio::spawn的地方可以起一个异步函数。

async fn rs_handle(rs_stream: Arc<TcpStream>, stream: Arc<TcpStream>) {
    let mut client_data = [0 as u8; 50]; // using 50 byte buffer
    while match rs_stream.as_ref().read(&mut client_data) {
        Ok(client_size) => {
            if client_size == 0 {
                false
            } else {
                println!("len: {:?}", client_size);
                stream.as_ref().write(&client_data[0..client_size]).unwrap();
                true
            }
        }
        Err(_) => {
            println!("client error occurred.");
            false
        }
    } {}
}

async fn handle_client(stream: TcpStream) {
    let x = get_original_destination_addr(&stream).unwrap();
    println!("target addr: {:?}", x);

    let stream = Arc::new(stream);
    let rs_stream = Arc::new(get_rs_connect_stream_by_ipport(x.to_string().as_str()));
    println!("rs stream: {:?}", rs_stream);

    let rsh_rs_stream = rs_stream.clone();
    let rsh_stream = stream.clone();
    tokio::spawn(rs_handle(rsh_rs_stream, rsh_stream)); // 这里相当于起一个线程
    println!("rs handle setted.");

    let mut data = [0 as u8; 50]; // using 50 byte buffer
    while match stream.as_ref().read(&mut data) {
        Ok(size) => {
            if size == 0 {
                false
            } else {
                rs_stream.as_ref().write(&data[0..size]).unwrap();
                true
            }
        }
        Err(_) => {
            println!(
                "An error occurred, terminating connection with {}",
                stream.peer_addr().unwrap()
            );
            stream.shutdown(Shutdown::Both).unwrap();
            false
        }
    } {}
}

如上代码,首先需要定义一个异步函数async fn rs_handle(...)然后在调用的地方使用tokio::spawn(rs_handle(...))来调用。就可以实现同等于golang中go关键词的效果,起一个协程。