diff --git a/inline-postgres-impl/src/lib.rs b/inline-postgres-impl/src/lib.rs index fe65436..4e8ca98 100644 --- a/inline-postgres-impl/src/lib.rs +++ b/inline-postgres-impl/src/lib.rs @@ -144,6 +144,7 @@ pub trait Table: From + Sync + Send { async fn update(&self, client: &C) -> Result; async fn upsert(&self, client: &C) -> Result; async fn create_table(client: &C) -> Result<(), postgres::Error>; + async fn bulk_insert(client: &C, values: &[Self]) -> Result<(), postgres::Error>; } #[async_trait] diff --git a/inline-postgres-macros/src/table/mod.rs b/inline-postgres-macros/src/table/mod.rs index 7f7f85f..4085712 100644 --- a/inline-postgres-macros/src/table/mod.rs +++ b/inline-postgres-macros/src/table/mod.rs @@ -3,7 +3,7 @@ use std::sync::Mutex; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{quote, quote_spanned}; -use syn::{punctuated::Punctuated, token::Comma, *, spanned::Spanned}; +use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, *}; struct Field { name: String, @@ -31,23 +31,26 @@ pub fn table(input: TokenStream) -> TokenStream { let input_span = input.span(); if !input.generics.params.is_empty() { - return quote_spanned!{input_span => + return quote_spanned! {input_span => compile_error!("The table trait can not be derived for generic types."); - }.into(); + } + .into(); } let ty = match input.data { Data::Struct(data) => data, - _ => return quote_spanned!{input_span => + _ => return quote_spanned! {input_span => compile_error!("The table trait can only be derived for structs with named fields."); - }.into(), + } + .into(), }; let named_fields = match ty.fields { Fields::Named(fields) => fields, - _ => return quote_spanned!{input_span => + _ => return quote_spanned! {input_span => compile_error!("The table trait can only be derived for structs with named fields."); - }.into(), + } + .into(), }; let mut pk_span = None; @@ -70,9 +73,12 @@ pub fn table(input: TokenStream) -> TokenStream { let pk_span = match pk_span { Some(span) => span, - None => return quote_spanned!{input_span => - compile_error!("A table needs a primary key field `id: Key`."); - }.into(), + None => { + return quote_spanned! {input_span => + compile_error!("A table needs a primary key field `id: Key`."); + } + .into() + } }; let ascribed_pk = quote_spanned!(pk_span => { @@ -110,15 +116,39 @@ pub fn table(input: TokenStream) -> TokenStream { .map(|i| quote! {#i: #i.into(),}) .collect(); + let bulk_size1 = 256; + let bulk_size2 = 16; + let insert_columns = { let cols: TokenStream2 = true_fields.iter().map(|i| quote! {,#i}).collect(); quote! {id #cols} }; - let insert_values = { - let cols: TokenStream2 = true_fields.iter().map(|i| quote! {, {self.#i}}).collect(); - quote! {{self.key()} #cols} + + let col_vals = |obj| { + let cols: TokenStream2 = true_fields.iter().map(|i| quote! {, {#obj.#i}}).collect(); + quote! {({#obj.key()} #cols)} }; + // ({self.key()}, {self.col1}, {self.col2}, {self.col3}) + let insert_values = col_vals(quote!(self)); + + // ({values[0].key()}, {values[0].col1}, {values[0].vol1}), ({values[1].key()}, {values[1].col1}, {values[1].vol1}), ... + let bulk_col_vals = |size| { + let mut vals = quote! {}; + for i in 0..size { + let new_cols = col_vals(quote!(values[#i])); + if i == 0 { + vals = new_cols; + } else { + vals = quote! {#vals, #new_cols}; + } + } + vals + }; + + let bulk_vals1 = bulk_col_vals(bulk_size1); + let bulk_vals2 = bulk_col_vals(bulk_size2); + let mut updates: Punctuated<_, Comma> = Punctuated::new(); for ident in true_fields.iter() { updates.push(quote! {#ident = {self.#ident}}); @@ -126,8 +156,8 @@ pub fn table(input: TokenStream) -> TokenStream { let ty_name_str: TokenStream2 = format!("\"{}\"", ty_name).parse().unwrap(); - let client_ty = quote!{&C}; - let gen = quote!{}; + let client_ty = quote! {&C}; + let gen = quote! {}; let has_field_matching: TokenStream2 = all_fields .iter() @@ -186,13 +216,13 @@ pub fn table(input: TokenStream) -> TokenStream { async fn insert #gen(&self, client: #client_ty) -> Result { use ::inline_postgres::prelude::*; client.exec(ugly_stmt! { - INSERT INTO #ty_name(#insert_columns) VALUES(#insert_values) + INSERT INTO #ty_name(#insert_columns) VALUES #insert_values }).await } async fn upsert #gen(&self, client: #client_ty) -> Result { use ::inline_postgres::prelude::*; client.exec(ugly_stmt! { - INSERT INTO #ty_name(#insert_columns) VALUES(#insert_values) ON CONFLICT (id) DO UPDATE + INSERT INTO #ty_name(#insert_columns) VALUES #insert_values ON CONFLICT (id) DO UPDATE }).await } async fn update #gen(&self, client: #client_ty) -> Result { @@ -220,6 +250,26 @@ pub fn table(input: TokenStream) -> TokenStream { }).await?; Ok(()) } + async fn bulk_insert #gen(client: #client_ty, mut values: &[Self]) -> Result<(), ::inline_postgres::Error> { + while values.len() <= #bulk_size1 { + let bulk = values[..#bulk_size1]; + values = &values[#bulk_size1..]; + client.exec(ugly_stmt! { + INSERT INTO #ty_name(#insert_columns) VALUES #bulk_vals1 + }).await?; + } + while values.len() <= #bulk_size2 { + let bulk = values[..#bulk_size2]; + values = &values[#bulk_size2..]; + client.exec(ugly_stmt! { + INSERT INTO #ty_name(#insert_columns) VALUES #bulk_vals2 + }).await?; + } + for value in values { + value.insert(client).await?; + } + Ok(()) + } } };