//! This file defines the endpoints for the ML API for importing and exporting SurrealML models.
use crate::dbs::DB;
use crate::err::Error;
use crate::net::output;
use axum::extract::{BodyStream, DefaultBodyLimit, Path};
use axum::response::IntoResponse;
use axum::response::Response;
use axum::routing::{get, post};
use axum::Extension;
use axum::Router;
use bytes::Bytes;
use futures_util::StreamExt;
use http::StatusCode;
use http_body::Body as HttpBody;
use hyper::body::Body;
use surrealdb::dbs::Session;
use surrealdb::iam::check::check_ns_db;
use surrealdb::iam::Action::{Edit, View};
use surrealdb::iam::ResourceKind::Model;
use surrealdb::kvs::{LockType::Optimistic, TransactionType::Read};
use surrealdb::sql::statements::{DefineModelStatement, DefineStatement};
use surrealml_core::storage::surml_file::SurMlFile;
use tower_http::limit::RequestBodyLimitLayer;

/// Largest request body, in bytes, accepted by the ML routes (4 GiB).
const MAX: usize = 4 << 30;

/// Builds the router that exposes the ML model import and export endpoints.
///
/// Registers `POST /ml/import` and `GET /ml/export/:name/:version`, and caps
/// the size of accepted request bodies at `MAX` bytes.
pub(super) fn router<S, B>() -> Router<S, B>
where
	B: HttpBody + Send + 'static,
	B::Data: Send + Into<Bytes>,
	B::Error: std::error::Error + Send + Sync + 'static,
	S: Clone + Send + Sync + 'static,
{
	// Switch off axum's default body limit for the routes registered here...
	let without_default_limit = DefaultBodyLimit::disable();
	// ...and enforce our own, much larger, limit instead
	let with_max_body_size = RequestBodyLimitLayer::new(MAX);
	// Assemble the routes and attach both layers
	Router::new()
		.route("/ml/import", post(import))
		.route("/ml/export/:name/:version", get(export))
		.route_layer(without_default_limit)
		.layer(with_max_body_size)
}

/// Imports an uploaded SurrealML model into the current namespace and database.
///
/// The request body is read fully into memory, parsed as a SurrealML file,
/// written to the object store under a path derived from the namespace,
/// database, model name, model version and content hash, and finally
/// registered in the database as a model definition.
///
/// # Errors
///
/// Fails if the session has no namespace/database selected, if the session is
/// not permitted to edit models, if reading a body chunk fails, if the upload
/// is not a valid SurrealML file, or if storing or defining the model fails.
///
/// # Panics
///
/// Panics if `DB.get()` does not yield a datastore.
async fn import(
	Extension(session): Extension<Session>,
	mut stream: BodyStream,
) -> Result<impl IntoResponse, impl IntoResponse> {
	// Grab the shared datastore handle
	let db = DB.get().unwrap();
	// A namespace and a database must be selected on the session
	let (nsv, dbv) = check_ns_db(&session)?;
	// The session must be allowed to edit models in this database
	db.check(&session, Edit, Model.on_db(&nsv, &dbv))?;
	// Gather every uploaded chunk into one contiguous buffer
	let mut uploaded = Vec::new();
	while let Some(chunk) = stream.next().await {
		let chunk = chunk?;
		uploaded.extend_from_slice(&chunk);
	}
	// Parse the upload to verify that it is a valid SurrealML file. The
	// explicit `return Err(..)` is what gives the opaque `impl IntoResponse`
	// error type its concrete type, so it must not be replaced with `?`.
	let file = match SurMlFile::from_bytes(uploaded) {
		Err(err) => return Err(Error::Other(err.to_string())),
		Ok(file) => file,
	};
	// Serialise the parsed file back in to raw bytes
	let data = file.to_bytes();
	// Hash the serialised model file
	let hash = surrealdb::obs::hash(&data);
	// Read the identifying header fields once, for both the path and the definition
	let name = file.header.name.to_string();
	let version = file.header.version.to_string();
	// Work out where the model file lives in the store
	let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{hash}.surml");
	// Write the file data in to the store
	surrealdb::obs::put(&path, data).await?;
	// Describe the model which should be defined in the database
	let definition = DefineModelStatement {
		hash,
		name: name.into(),
		version,
		comment: Some(file.header.description.to_string().into()),
		..Default::default()
	};
	// Register the model in the database
	db.process(DefineStatement::Model(definition).into(), &session, None).await?;
	// Nothing to return on success
	Ok(output::none())
}

/// This endpoint allows the user to export a model from the database.
///
/// Looks up the model definition for `name` and `version` in the session's
/// namespace and database, then streams the matching model file from the
/// object store back to the client as a chunked `200 OK` response.
///
/// # Errors
///
/// Fails if the session has no namespace/database selected, if the session is
/// not permitted to view models, if the read transaction cannot be started,
/// if the model definition cannot be fetched, or if the stored file cannot be
/// opened for streaming. Failures which occur once streaming has started can
/// no longer change the status code; the response body is aborted instead.
///
/// # Panics
///
/// Panics if `DB.get()` does not yield a datastore.
async fn export(
	Extension(session): Extension<Session>,
	Path((name, version)): Path<(String, String)>,
) -> Result<impl IntoResponse, Error> {
	// Get the datastore reference
	let db = DB.get().unwrap();
	// Ensure a NS and DB are set
	let (nsv, dbv) = check_ns_db(&session)?;
	// Check the permissions level
	db.check(&session, View, Model.on_db(&nsv, &dbv))?;
	// Start a new readonly transaction
	let mut tx = db.transaction(Read, Optimistic).await?;
	// Attempt to get the model definition
	let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
	// Calculate the path of the model file
	let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{}.surml", info.hash);
	// Open a stream over the file data held in the store
	let mut data = surrealdb::obs::stream(path).await?;
	// Create a chunked response
	let (mut chn, body) = Body::channel();
	// Forward all stream values to the response body
	tokio::spawn(async move {
		while let Some(chunk) = data.next().await {
			match chunk {
				Ok(v) => {
					// A failed send means the receiving side of the body is
					// gone, so stop reading the rest of the file from the store
					if chn.send_data(v).await.is_err() {
						return;
					}
				}
				Err(_) => {
					// Reading from the store failed part way through. Abort the
					// body so that the client sees a failed transfer, instead of
					// a truncated file which looks like a complete response.
					chn.abort();
					return;
				}
			}
		}
	});
	// Return the streamed body
	Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
}