Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Feat: Support for adding Custom Domains with railway domain [domain] #571

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
305 changes: 244 additions & 61 deletions src/commands/domain.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::time::Duration;
use std::{cmp::max, time::Duration};

use anyhow::bail;
use colored::Colorize;
use is_terminal::IsTerminal;
use queries::domains::DomainsDomains;
use serde_json::json;

use crate::{
consts::TICK_STRING,
Expand All @@ -13,12 +14,36 @@ use crate::{

use super::*;

/// Generates a domain for a service if there is not a railway provided domain
// Checks if the user is linked to a service, if not, it will generate a domain for the default service
/// Add a custom domain or generate a railway provided domain for a service.
///
/// There is a maximum of 1 railway provided domain per service.
#[derive(Parser)]
pub struct Args {}
pub struct Args {
/// The port to connect to the domain
#[clap(short, long)]
port: Option<u16>,

pub async fn command(_args: Args, _json: bool) -> Result<()> {
/// The name of the service to generate the domain for
#[clap(short, long)]
service: Option<String>,

/// Optionally, specify a custom domain to use. If not specified, a domain will be generated.
///
/// Specifying a custom domain will also return the required DNS records
/// to add to your DNS settings
domain: Option<String>,
}

pub async fn command(args: Args, json: bool) -> Result<()> {
if let Some(domain) = args.domain {
create_custom_domain(domain, args.port, args.service, json).await?;
alexng353 marked this conversation as resolved.
Show resolved Hide resolved
} else {
create_service_domain(args.service, json).await?;
}
Ok(())
}

async fn create_service_domain(service_name: Option<String>, json: bool) -> Result<()> {
let configs = Configs::new()?;

let client = GQLClient::new_authorized(&configs)?;
Expand All @@ -28,84 +53,49 @@ pub async fn command(_args: Args, _json: bool) -> Result<()> {

let project = get_project(&client, &configs, linked_project.project.clone()).await?;

if project.services.edges.is_empty() {
return Err(RailwayError::NoServices.into());
}

// If there is only one service, it will generate a domain for that service
let service = if project.services.edges.len() == 1 {
project.services.edges[0].node.clone().id
} else {
let Some(service) = linked_project.service.clone() else {
bail!("No service linked. Run `railway service` to link to a service");
};
if project.services.edges.iter().any(|s| s.node.id == service) {
service
} else {
bail!("Service not found! Run `railway service` to link to a service");
}
};
let service = get_service(&linked_project, &project, service_name)?;

let vars = queries::domains::Variables {
project_id: linked_project.project.clone(),
environment_id: linked_project.environment.clone(),
service_id: service.clone(),
service_id: service.id.clone(),
};

let domains = post_graphql::<queries::Domains, _>(&client, configs.get_backboard(), vars)
.await?
.domains;

let domain_count = domains.service_domains.len() + domains.custom_domains.len();

if domain_count > 0 {
return print_existing_domains(&domains);
}

let spinner = (std::io::stdout().is_terminal() && !json)
.then(|| creating_domain_spiner(None))
.and_then(|s| s.ok());

let vars = mutations::service_domain_create::Variables {
service_id: service,
service_id: service.id.clone(),
environment_id: linked_project.environment.clone(),
};
let domain =
post_graphql::<mutations::ServiceDomainCreate, _>(&client, configs.get_backboard(), vars)
.await?
.service_domain_create
.domain;

if std::io::stdout().is_terminal() {
let spinner = indicatif::ProgressBar::new_spinner()
.with_style(
indicatif::ProgressStyle::default_spinner()
.tick_chars(TICK_STRING)
.template("{spinner:.green} {msg}")?,
)
.with_message("Creating domain...");
spinner.enable_steady_tick(Duration::from_millis(100));

let domain = post_graphql::<mutations::ServiceDomainCreate, _>(
&client,
configs.get_backboard(),
vars,
)
.await?
.service_domain_create
.domain;

if let Some(spinner) = spinner {
spinner.finish_and_clear();
}

let formatted_domain = format!("https://{}", domain);
println!(
"Service Domain created:\n🚀 {}",
formatted_domain.magenta().bold()
);
} else {
println!("Creating domain...");

let domain = post_graphql::<mutations::ServiceDomainCreate, _>(
&client,
configs.get_backboard(),
vars,
)
.await?
.service_domain_create
.domain;
let formatted_domain = format!("https://{}", domain);
if json {
let out = json!({
"domain": formatted_domain
});

let formatted_domain = format!("https://{}", domain);
println!("{}", serde_json::to_string_pretty(&out)?);
} else {
println!(
"Service Domain created:\n🚀 {}",
formatted_domain.magenta().bold()
Expand Down Expand Up @@ -148,3 +138,196 @@ fn print_existing_domains(domains: &DomainsDomains) -> Result<()> {

Ok(())
}

// Returns a reference to save on Heap allocations
pub fn get_service<'a>(
linked_project: &'a LinkedProject,
project: &'a queries::project::ProjectProject,
service_name: Option<String>,
) -> anyhow::Result<&'a queries::project::ProjectProjectServicesEdgesNode> {
let services = project.services.edges.iter().collect::<Vec<_>>();

if services.is_empty() {
bail!(RailwayError::NoServices);
}

if project.services.edges.len() == 1 {
return Ok(&project.services.edges[0].node);
}

if let Some(service_name) = service_name {
if let Some(service) = project
.services
.edges
.iter()
.find(|s| s.node.name == service_name)
{
return Ok(&service.node);
}

bail!(RailwayError::ServiceNotFound(service_name));
}

if let Some(service) = linked_project.service.clone() {
if project.services.edges.iter().any(|s| s.node.id == service) {
return Ok(&project
.services
.edges
.iter()
.find(|s| s.node.id == service)
.unwrap()
.node);
}
}

bail!(RailwayError::NoServices);
}

pub fn creating_domain_spiner(message: Option<String>) -> anyhow::Result<indicatif::ProgressBar> {
let spinner = indicatif::ProgressBar::new_spinner()
.with_style(
indicatif::ProgressStyle::default_spinner()
.tick_chars(TICK_STRING)
.template("{spinner:.green} {msg}")?,
)
.with_message(message.unwrap_or_else(|| "Creating domain...".to_string()));
spinner.enable_steady_tick(Duration::from_millis(100));

Ok(spinner)
}

async fn create_custom_domain(
domain: String,
port: Option<u16>,
service_name: Option<String>,
json: bool,
) -> Result<()> {
let configs = Configs::new()?;

let client = GQLClient::new_authorized(&configs)?;
let linked_project = configs.get_linked_project().await?;

ensure_project_and_environment_exist(&client, &configs, &linked_project).await?;

let project = get_project(&client, &configs, linked_project.project.clone()).await?;

let service = get_service(&linked_project, &project, service_name)?;

let spinner = (std::io::stdout().is_terminal() && !json)
.then(|| {
creating_domain_spiner(Some(format!(
"Creating custom domain for service {}{}...",
service.name,
port.map(|p| format!(" on port {}", p)).unwrap_or_default()
)))
})
.and_then(|s| s.ok());

let is_available = post_graphql::<queries::CustomDomainAvailable, _>(
&client,
configs.get_backboard(),
queries::custom_domain_available::Variables {
domain: domain.clone(),
},
)
.await?
.custom_domain_available
.available;

if !is_available {
bail!("Domain is not available:\n\t{}", domain);
}

let vars = mutations::custom_domain_create::Variables {
input: mutations::custom_domain_create::CustomDomainCreateInput {
domain: domain.clone(),
environment_id: linked_project.environment.clone(),
project_id: linked_project.project.clone(),
service_id: service.id.clone(),
target_port: port.map(|p| p as i64),
},
};

let response =
post_graphql::<mutations::CustomDomainCreate, _>(&client, configs.get_backboard(), vars)
.await?;

spinner.map(|s| s.finish_and_clear());

if json {
println!("{}", serde_json::to_string_pretty(&response)?);
return Ok(());
}

println!("Domain created: {}", response.custom_domain_create.domain);

if response.custom_domain_create.status.dns_records.is_empty() {
// This case should be impossible, but added error handling for safety.
//
// It can only occur if the backend is not returning the correct data,
// and in that case, the post_graphql call should have already errored.
bail!("No DNS records found. Please check the Railway dashboard for more information.");
}

println!(
"To finish setting up your custom domain, add the following DNS records to {}:\n",
&response.custom_domain_create.status.dns_records[0].zone
);

print_dns(response.custom_domain_create.status.dns_records);
coffee-cup marked this conversation as resolved.
Show resolved Hide resolved

println!("\nNote: if the Name is \"@\", the DNS record should be created for the root of the domain.");
println!("*DNS changes can take up to 72 hours to propagate worldwide.");

Ok(())
}

fn print_dns(
domains: Vec<
mutations::custom_domain_create::CustomDomainCreateCustomDomainCreateStatusDnsRecords,
>,
) {
// I benchmarked this iter().fold() and it's faster than using 3x iter().map()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine, but just for future reference, readability is much more important here than performance. Especially considering the domains length will be size 1 99% of the time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense. For n=1, the performance is entirely irrelevant.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even for n=2. The difference will be so small it is negligible. Unless thousands of domains are being processed, readability is what we should optimize for.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I change it back to three iter maps then?

let (padding_type, padding_hostlabel, padding_value) = domains
.iter()
// Minimum length should be 8, but we add 3 for extra padding so 8-3 = 5
.fold((5, 5, 5), |(max_type, max_hostlabel, max_value), d| {
(
max(max_type, d.record_type.to_string().len()),
max(max_hostlabel, d.hostlabel.len()),
max(max_value, d.required_value.len()),
)
});

// Add extra minimum padding to each length
let [padding_type, padding_hostlabel, padding_value] =
[padding_type + 3, padding_hostlabel + 3, padding_value + 3];

// Print the header with consistent padding
println!(
"\t{:<width_type$}{:<width_host$}{:<width_value$}",
"Type",
"Name",
"Value",
width_type = padding_type,
width_host = padding_hostlabel,
width_value = padding_value
);

// Print each domain entry with the same padding
for domain in &domains {
println!(
"\t{:<width_type$}{:<width_host$}{:<width_value$}",
domain.record_type.to_string(),
if domain.hostlabel.is_empty() {
"@"
} else {
&domain.hostlabel
},
domain.required_value,
width_type = padding_type,
width_host = padding_hostlabel,
width_value = padding_value
);
}
}
25 changes: 25 additions & 0 deletions src/gql/mutations/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use graphql_client::GraphQLQuery;
use serde::{Deserialize, Serialize};
type EnvironmentVariables = std::collections::BTreeMap<String, String>;
use chrono::{DateTime as DateTimeType, Utc};

pub type DateTime = DateTimeType<Utc>;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -148,3 +151,25 @@ pub struct VariableCollectionUpsert;
skip_serializing_none
)]
pub struct ServiceCreate;

#[derive(GraphQLQuery)]
#[graphql(
schema_path = "src/gql/schema.json",
query_path = "src/gql/mutations/strings/CustomDomainCreate.graphql",
response_derives = "Debug, Serialize, Clone",
skip_serializing_none
)]
pub struct CustomDomainCreate;

impl std::fmt::Display for custom_domain_create::DNSRecordType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::DNS_RECORD_TYPE_CNAME => write!(f, "CNAME"),
Self::DNS_RECORD_TYPE_A => write!(f, "A"),
Self::DNS_RECORD_TYPE_NS => write!(f, "NS"),
Self::DNS_RECORD_TYPE_UNSPECIFIED => write!(f, "UNSPECIFIED"),
Self::UNRECOGNIZED => write!(f, "UNRECOGNIZED"),
Self::Other(s) => write!(f, "{}", s),
}
}
}
Loading