add true bulk insertion

This commit is contained in:
Jonas Maier 2023-07-07 12:08:39 +02:00
parent bae89ab882
commit 647d81b645
2 changed files with 68 additions and 17 deletions

View File

@ -144,6 +144,7 @@ pub trait Table: From<Row> + Sync + Send {
async fn update<C: postgres::GenericClient + Sync>(&self, client: &C) -> Result<u64, postgres::Error>; async fn update<C: postgres::GenericClient + Sync>(&self, client: &C) -> Result<u64, postgres::Error>;
async fn upsert<C: postgres::GenericClient + Sync>(&self, client: &C) -> Result<u64, postgres::Error>; async fn upsert<C: postgres::GenericClient + Sync>(&self, client: &C) -> Result<u64, postgres::Error>;
async fn create_table<C: postgres::GenericClient + Sync>(client: &C) -> Result<(), postgres::Error>; async fn create_table<C: postgres::GenericClient + Sync>(client: &C) -> Result<(), postgres::Error>;
async fn bulk_insert<C: postgres::GenericClient + Sync>(client: &C, values: &[Self]) -> Result<(), postgres::Error>;
} }
#[async_trait] #[async_trait]

View File

@ -3,7 +3,7 @@ use std::sync::Mutex;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2; use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, quote_spanned}; use quote::{quote, quote_spanned};
use syn::{punctuated::Punctuated, token::Comma, *, spanned::Spanned}; use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, *};
struct Field { struct Field {
name: String, name: String,
@ -31,23 +31,26 @@ pub fn table(input: TokenStream) -> TokenStream {
let input_span = input.span(); let input_span = input.span();
if !input.generics.params.is_empty() { if !input.generics.params.is_empty() {
return quote_spanned!{input_span => return quote_spanned! {input_span =>
compile_error!("The table trait can not be derived for generic types."); compile_error!("The table trait can not be derived for generic types.");
}.into(); }
.into();
} }
let ty = match input.data { let ty = match input.data {
Data::Struct(data) => data, Data::Struct(data) => data,
_ => return quote_spanned!{input_span => _ => return quote_spanned! {input_span =>
compile_error!("The table trait can only be derived for structs with named fields."); compile_error!("The table trait can only be derived for structs with named fields.");
}.into(), }
.into(),
}; };
let named_fields = match ty.fields { let named_fields = match ty.fields {
Fields::Named(fields) => fields, Fields::Named(fields) => fields,
_ => return quote_spanned!{input_span => _ => return quote_spanned! {input_span =>
compile_error!("The table trait can only be derived for structs with named fields."); compile_error!("The table trait can only be derived for structs with named fields.");
}.into(), }
.into(),
}; };
let mut pk_span = None; let mut pk_span = None;
@ -70,9 +73,12 @@ pub fn table(input: TokenStream) -> TokenStream {
let pk_span = match pk_span { let pk_span = match pk_span {
Some(span) => span, Some(span) => span,
None => return quote_spanned!{input_span => None => {
compile_error!("A table needs a primary key field `id: Key<Self>`."); return quote_spanned! {input_span =>
}.into(), compile_error!("A table needs a primary key field `id: Key<Self>`.");
}
.into()
}
}; };
let ascribed_pk = quote_spanned!(pk_span => { let ascribed_pk = quote_spanned!(pk_span => {
@ -110,15 +116,39 @@ pub fn table(input: TokenStream) -> TokenStream {
.map(|i| quote! {#i: #i.into(),}) .map(|i| quote! {#i: #i.into(),})
.collect(); .collect();
let bulk_size1 = 256;
let bulk_size2 = 16;
let insert_columns = { let insert_columns = {
let cols: TokenStream2 = true_fields.iter().map(|i| quote! {,#i}).collect(); let cols: TokenStream2 = true_fields.iter().map(|i| quote! {,#i}).collect();
quote! {id #cols} quote! {id #cols}
}; };
let insert_values = {
let cols: TokenStream2 = true_fields.iter().map(|i| quote! {, {self.#i}}).collect(); let col_vals = |obj| {
quote! {{self.key()} #cols} let cols: TokenStream2 = true_fields.iter().map(|i| quote! {, {#obj.#i}}).collect();
quote! {({#obj.key()} #cols)}
}; };
// ({self.key()}, {self.col1}, {self.col2}, {self.col3})
let insert_values = col_vals(quote!(self));
// ({values[0].key()}, {values[0].col1}, {values[0].vol1}), ({values[1].key()}, {values[1].col1}, {values[1].vol1}), ...
let bulk_col_vals = |size| {
let mut vals = quote! {};
for i in 0..size {
let new_cols = col_vals(quote!(values[#i]));
if i == 0 {
vals = new_cols;
} else {
vals = quote! {#vals, #new_cols};
}
}
vals
};
let bulk_vals1 = bulk_col_vals(bulk_size1);
let bulk_vals2 = bulk_col_vals(bulk_size2);
let mut updates: Punctuated<_, Comma> = Punctuated::new(); let mut updates: Punctuated<_, Comma> = Punctuated::new();
for ident in true_fields.iter() { for ident in true_fields.iter() {
updates.push(quote! {#ident = {self.#ident}}); updates.push(quote! {#ident = {self.#ident}});
@ -126,8 +156,8 @@ pub fn table(input: TokenStream) -> TokenStream {
let ty_name_str: TokenStream2 = format!("\"{}\"", ty_name).parse().unwrap(); let ty_name_str: TokenStream2 = format!("\"{}\"", ty_name).parse().unwrap();
let client_ty = quote!{&C}; let client_ty = quote! {&C};
let gen = quote!{<C: ::inline_postgres::GenericClient + Sync>}; let gen = quote! {<C: ::inline_postgres::GenericClient + Sync>};
let has_field_matching: TokenStream2 = all_fields let has_field_matching: TokenStream2 = all_fields
.iter() .iter()
@ -186,13 +216,13 @@ pub fn table(input: TokenStream) -> TokenStream {
async fn insert #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> { async fn insert #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> {
use ::inline_postgres::prelude::*; use ::inline_postgres::prelude::*;
client.exec(ugly_stmt! { client.exec(ugly_stmt! {
INSERT INTO #ty_name(#insert_columns) VALUES(#insert_values) INSERT INTO #ty_name(#insert_columns) VALUES #insert_values
}).await }).await
} }
async fn upsert #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> { async fn upsert #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> {
use ::inline_postgres::prelude::*; use ::inline_postgres::prelude::*;
client.exec(ugly_stmt! { client.exec(ugly_stmt! {
INSERT INTO #ty_name(#insert_columns) VALUES(#insert_values) ON CONFLICT (id) DO UPDATE INSERT INTO #ty_name(#insert_columns) VALUES #insert_values ON CONFLICT (id) DO UPDATE
}).await }).await
} }
async fn update #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> { async fn update #gen(&self, client: #client_ty) -> Result<u64, ::inline_postgres::Error> {
@ -220,6 +250,26 @@ pub fn table(input: TokenStream) -> TokenStream {
}).await?; }).await?;
Ok(()) Ok(())
} }
async fn bulk_insert #gen(client: #client_ty, mut values: &[Self]) -> Result<(), ::inline_postgres::Error> {
while values.len() <= #bulk_size1 {
let bulk = values[..#bulk_size1];
values = &values[#bulk_size1..];
client.exec(ugly_stmt! {
INSERT INTO #ty_name(#insert_columns) VALUES #bulk_vals1
}).await?;
}
while values.len() <= #bulk_size2 {
let bulk = values[..#bulk_size2];
values = &values[#bulk_size2..];
client.exec(ugly_stmt! {
INSERT INTO #ty_name(#insert_columns) VALUES #bulk_vals2
}).await?;
}
for value in values {
value.insert(client).await?;
}
Ok(())
}
} }
}; };