ps/Modules/Alkami.Ops.Certificates/Cmdlets/ImportCertificatesFromSecretServer.cs
2023-05-30 22:51:22 -07:00

487 lines
23 KiB
C#

using Alkami.Ops.Certificates.Data;
using Alkami.Ops.Certificates.SecretServer;
using Alkami.Ops.Certificates.SecretServer.Models;
using Alkami.Ops.Certificates.Utilities;
using Alkami.Ops.Common.Cryptography;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Management.Automation;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.ServiceProcess;
namespace Alkami.Ops.Certificates
{
/// <summary>
/// Downloads certificates for the current server from the appropriate Secret Server MachineSecrets folder if it exists.
/// </summary>
[Cmdlet("Import", "CertificatesFromSecretServer")]
[OutputType(typeof(string))]
public class ImportCertificatesFromSecretServer : Cmdlet
{
/// <summary>
/// Username passed to the Secret Server client for authentication.
/// </summary>
[Parameter(Position = 0, Mandatory = true)]
public string SecretUsername { get; set; }
/// <summary>
/// Password passed to the Secret Server client for authentication.
/// </summary>
[Parameter(Position = 1, Mandatory = true)]
public string SecretPassword { get; set; }
/// <summary>
/// Optional GMSA prefix. When supplied, the "fh\{prefix}.nag$" and "fh\{prefix}.radium$" accounts
/// are added to the users that are granted rights to imported certificates.
/// </summary>
[Parameter(Position = 2, Mandatory = false)]
public string GrantUserGmsaPrefix { get; set; }
/// <summary>
/// Secret Server URI handed to the Secret Server client.
/// </summary>
[Parameter(Position = 3, Mandatory = false)]
public string SecretSite { get; set; } = "https://alkami.secretservercloud.com";
/// <summary>
/// Base Secret Server folder. The environment type and then "Common" or the environment name
/// are appended to it to locate the folders that are searched for secrets.
/// </summary>
[Parameter(Position = 4, Mandatory = false)]
public string MachineSecretFolder { get; set; } = "ops.deployment-CertApi/MachineSecrets";
/// <summary>
/// Directory that holds the per-secret tracking files and "untracked.json".
/// The default value ends with a directory separator.
/// </summary>
[Parameter(Position = 5, Mandatory = false)]
public string thumbprintsFilePath { get; set; } = @"C:\Tools\CertificateManagement\TrackedThumbprints\";
// Folder names expected inside an unzipped secret; each is mapped to a certificate store
// through Extensions.GetStoreNameByFolderName.
private readonly string[] storeTypes = new string[] { "personal", "ia", "root", "trustedpeople" };
// File extensions that are considered certificates when importing from a store folder.
private string[] extensions = new string[] { ".pfx", ".cer" };
/// <summary>
/// Cmdlet entry point. Prepares an empty staging directory named "CertificateImport" under the
/// system temp path, makes sure the thumbprint tracking directory exists, runs the import, and
/// removes the staging directory afterwards regardless of whether the import succeeded.
/// </summary>
protected override void ProcessRecord()
{
    var stagingDirectory = Path.Combine(Path.GetTempPath(), "CertificateImport");

    try
    {
        // Start from an empty staging directory: create it, or wipe what a previous run left behind.
        if (!Directory.Exists(stagingDirectory))
        {
            Directory.CreateDirectory(stagingDirectory);
        }
        else
        {
            Extensions.ClearDirectory(stagingDirectory);
        }

        // The tracking directory stores the per-secret files and the untracked certificate list.
        var trackingDirectoryExists = Directory.Exists(this.thumbprintsFilePath);
        if (!trackingDirectoryExists)
        {
            Directory.CreateDirectory(this.thumbprintsFilePath);
        }

        ImportCertificates(
            downloadPath: stagingDirectory,
            gmsaPrefix: this.GrantUserGmsaPrefix,
            machineSecretFolder: this.MachineSecretFolder,
            secretPassword: this.SecretPassword,
            secretSite: this.SecretSite,
            secretUsername: this.SecretUsername);
    }
    finally
    {
        // Downloaded zips and extracted certificates must not outlive the run.
        var stagingDirectoryStillPresent = Directory.Exists(stagingDirectory);
        if (stagingDirectoryStillPresent)
        {
            Directory.Delete(stagingDirectory, true);
        }
    }
}
/// <summary>
/// Imports certificates to the machine from the MachineSecrets folder.
/// Secrets whose zip matches a hash recorded in the tracking directory are skipped; the remaining
/// secrets are downloaded, unzipped, imported into the local machine stores, and recorded in a
/// "&lt;secret name&gt;.json" tracking file. Finally the untracked certificate list is refreshed.
/// </summary>
/// <param name="downloadPath">Path where secret zips will be stored temporarily</param>
/// <param name="gmsaPrefix">Pod prefix for the GMSA accounts</param>
/// <param name="machineSecretFolder">Name of the folder in Secret Server from which we're downloading secrets</param>
/// <param name="secretPassword">Password used to authenticate with Secret Server</param>
/// <param name="secretSite">Secret Server URI</param>
/// <param name="secretUsername">Username used to authenticate with Secret Server</param>
/// <exception cref="Exception">Thrown when a secret's zip file cannot be downloaded.</exception>
private void ImportCertificates(string downloadPath, string gmsaPrefix, string machineSecretFolder, string secretPassword, string secretSite, string secretUsername)
{
    Console.WriteLine("Importing certificates to server from the secret server.");

    // Get the environment properties of the current server.
    var serverInfo = Extensions.GetServerInfo("localhost");
    if (serverInfo == null)
    {
        Console.WriteLine("Environment properties could not be successfully read from the machine config. Exiting...");
        return;
    }

    // Create secret server client.
    using (var client = new SecretServerClient(secretSite, secretUsername, secretPassword))
    {
        List<Secret> preliminarySecretList = GetSecretsForServer(machineSecretFolder, serverInfo, client);

        // Load the recorded zip info from every file in thumbprintsFilePath except untracked.json.
        var knownZips = new List<SecretZipInfo>();
        foreach (var file in Directory.GetFiles(thumbprintsFilePath))
        {
            if (file.IndexOf("untracked.json", StringComparison.Ordinal) != -1)
            {
                continue;
            }

            var fileData = JsonConvert.DeserializeObject<SecretZipInfo>(File.ReadAllText(file));
            if (fileData == null)
            {
                // An empty or "null" file carries no hash to compare against.
                continue;
            }

            knownZips.Add(fileData);
        }

        // Thumbprints of every certificate managed through Secret Server: the ones recorded for
        // unchanged secrets plus the ones imported during this run.
        var trackedThumbprints = new List<string>();

        var unchangedSecretIds = new HashSet<int>();
        foreach (var secret in preliminarySecretList)
        {
            foreach (var knownZip in knownZips)
            {
                if (!client.DetectChanges(secret.ID, knownZip.FileHash).GetAwaiter().GetResult())
                {
                    unchangedSecretIds.Add(secret.ID);

                    // The secret is skipped below, so carry its recorded thumbprints forward;
                    // otherwise its certificates would be reported as untracked.
                    if (knownZip.CertificateThumbprints != null)
                    {
                        trackedThumbprints.AddRange(knownZip.CertificateThumbprints);
                    }

                    Console.WriteLine($"Zip file has not changed for secret {secret.Name} from folder ID {secret.FolderID}. Skipping download.");

                    // One matching hash is enough; checking further hashes would only repeat the message.
                    break;
                }
            }

            if (!unchangedSecretIds.Contains(secret.ID))
            {
                Console.WriteLine($"Downloading secret '{secret.Name}' with ID {secret.ID} from folder ID {secret.FolderID}");
            }
        }

        var finalSecretList = preliminarySecretList.Where(s => !unchangedSecretIds.Contains(s.ID)).ToList();

        // If there are no secrets to download, return!
        if (finalSecretList.Count == 0)
        {
            Console.WriteLine("Could not locate any certificate secrets to download and install. Exiting...");
            return;
        }

        // Write debug info about which secrets are being downloaded.
        WriteVerbose($"Found {finalSecretList.Count} secrets to download.");

        // Fetch the full secret. GetSecretsByFolder only fetches the top-level secret info, and not the data in the secret.
        var fullSecrets = finalSecretList.Select(secret => client.GetSecretByID(secret.ID).GetAwaiter().GetResult()).ToList();

        // Download all the machine secrets.
        var certificateZipsAndPasswords = new List<(string zipPath, string importPassword, string secretName)>();
        foreach (var secret in fullSecrets)
        {
            var zipFilePath = Path.Combine(downloadPath, $"{Guid.NewGuid()}.zip");
            var downloadSuccessful = client.DownloadFile(zipFilePath, secret).GetAwaiter().GetResult();
            if (downloadSuccessful)
            {
                certificateZipsAndPasswords.Add((zipFilePath, secret["Import Password"], secret.Name));
            }
            else
            {
                throw new Exception($"Failed to download secret {secret.Name}");
            }
        }

        // Determine the users to grant rights to the certificates to.
        var grantUsers = GetUsersToGrantRightsTo(serverInfo.MicroUser, serverInfo.DatabaseUser, gmsaPrefix);

        // Unzip all of the cert zips, import the certs, and track which ones we just installed.
        var unzipPath = Path.Combine(downloadPath, "Certificates");
        foreach (var zipAndPassword in certificateZipsAndPasswords)
        {
            var unzipOutputDirectory = Path.Combine(unzipPath, Path.GetFileNameWithoutExtension(zipAndPassword.zipPath));

            string zipHash;
            using (var fileStream = File.OpenRead(zipAndPassword.zipPath))
            {
                zipHash = Extensions.GetMd5HashString(fileStream);
            }

            System.IO.Compression.ZipFile.ExtractToDirectory(zipAndPassword.zipPath, unzipOutputDirectory);
            ImportCertificatesToLocalMachine(serverInfo, unzipOutputDirectory, grantUsers, zipAndPassword.importPassword);

            // Collect the thumbprints that belong to this zip only, so the tracking file written
            // below describes exactly this secret.
            var thumbprintsFromSecret = new List<string>();
            var certificateFiles = Directory.GetFiles(unzipOutputDirectory, "*", SearchOption.AllDirectories);
            foreach (string certificateFile in certificateFiles)
            {
                using (var tempCertificate = new X509Certificate2(certificateFile, zipAndPassword.importPassword))
                {
                    thumbprintsFromSecret.Add(tempCertificate.Thumbprint);
                }
            }

            trackedThumbprints.AddRange(thumbprintsFromSecret);

            // Write certs pulled from Secret to file
            var secretZipInfo = new SecretZipInfo() { CertificateThumbprints = thumbprintsFromSecret, FileHash = zipHash };
            string thumbprintsFileName = Path.Combine(this.thumbprintsFilePath, zipAndPassword.secretName + ".json");
            File.WriteAllText(thumbprintsFileName, JsonConvert.SerializeObject(secretZipInfo));
        }

        // Compare local certs against list of all tracked certs.
        TrackUnregisteredCerts(trackedThumbprints);
        Console.WriteLine("Certificate import complete.");
    }
}
/// <summary>
/// Collects the secrets that apply to this server from the "Common" folder and from the
/// environment-specific folder beneath "{machineSecretFolder}/{EnvironmentType}".
/// Only secrets whose sanitized name is "all" or the server's type are returned.
/// </summary>
/// <param name="machineSecretFolder">Base Secret Server folder that contains the environment type folders</param>
/// <param name="serverInfo">Environment properties of the current server</param>
/// <param name="client">Client used to look up folders and secrets</param>
/// <returns>Matching secrets from both folders; empty when neither folder exists or nothing matches.</returns>
private List<Secret> GetSecretsForServer(string machineSecretFolder, ServerInfo serverInfo, SecretServerClient client)
{
    // Determine the base folder for the environment type, the common folder for that environment, and the folder for the specific environment.
    var baseFolderPath = Path.Combine(machineSecretFolder, serverInfo.EnvironmentType);
    var commonFolderPath = Path.Combine(baseFolderPath, "Common");
    var environmentFolderPath = Path.Combine(baseFolderPath, serverInfo.EnvironmentName);
    var commonFolder = client.GetFolder(commonFolderPath);
    var environmentFolder = client.GetFolder(environmentFolderPath);

    // Determine which secrets we need to download from the machine folders.
    // Names are matched case-insensitively, so no lower-casing is required here.
    var desiredServerTypes = new string[] { "all", serverInfo.ServerType };
    WriteVerbose($"Determined that ServerType is {serverInfo.ServerType}.");

    // Download the 4 secrets relevant to this environment if the folders/secrets exist.
    // EnvironmentName / (Web|App)&(All)
    // Common / (Web|App)&(All)
    var secretsToDownload = new List<Secret>();
    if (commonFolder != null)
    {
        secretsToDownload.AddRange(SelectSecretsForServerTypes(client.GetSecretsByFolder(commonFolder), desiredServerTypes));
    }

    // Pod specific secrets are added on top of the common ones. AddRange mutates the list;
    // Enumerable.Concat only returns a new sequence and would leave the list unchanged.
    if (environmentFolder != null)
    {
        secretsToDownload.AddRange(SelectSecretsForServerTypes(client.GetSecretsByFolder(environmentFolder), desiredServerTypes));
    }

    return secretsToDownload;
}

/// <summary>
/// Strips the uniqueness suffix from each secret name (everything from the first '-') and returns
/// the secrets whose remaining name matches one of the desired server types, ignoring case.
/// </summary>
private static List<Secret> SelectSecretsForServerTypes(IEnumerable<Secret> secrets, string[] desiredServerTypes)
{
    foreach (var secret in secrets)
    {
        // all secrets must have unique names, this sanitizes back to normal names
        secret.Name = secret.Name.Split('-')[0];
    }

    return secrets
        .Where(secret => desiredServerTypes.Contains(secret.Name, StringComparer.OrdinalIgnoreCase))
        .ToList();
}
/// <summary>
/// Takes a list of tracked thumbprints, compares it with the certificates currently installed in
/// the local machine stores listed in <c>storeTypes</c>, and writes the resulting untracked
/// certificate list to "untracked.json" in the thumbprint tracking directory.
/// </summary>
/// <param name="trackedThumbprints">List of certificate thumbprints. A null list is treated as empty.</param>
private void TrackUnregisteredCerts(List<string> trackedThumbprints)
{
    // Get cert thumbprints from every local store we manage.
    var allLocalCerts = new List<string>();
    foreach (var storeName in this.storeTypes)
    {
        var certStore = Extensions.GetStoreNameByFolderName(storeName);
        allLocalCerts.AddRange(GetAllThumbprintsFromStore(certStore));
    }

    // Compare 2 lists of certs (from secret, and from local store).
    var allUntrackedCerts = CompareCertThumbprints(allLocalCerts, trackedThumbprints ?? new List<string>());

    // Write untracked certs to file. Path.Combine keeps this working whether or not the
    // configured directory ends with a separator.
    var untrackedJsonFilePath = Path.Combine(this.thumbprintsFilePath, "untracked.json");
    File.WriteAllText(untrackedJsonFilePath, JsonConvert.SerializeObject(allUntrackedCerts));
}
/// <summary>
/// Takes two lists of thumbprints and returns the untracked certificates and when they were first found.
/// Entries already present in "untracked.json" are returned unchanged, including their original
/// timestamp; thumbprints that are not in the file yet are added with the current local time.
/// </summary>
/// <param name="localCerts">Collection of certificates from local store</param>
/// <param name="managedCertThumbprints">Collection of known managed certificates</param>
/// <returns>Dictionary of all untracked thumbprints paired with the first time they were found.</returns>
private Dictionary<string, DateTime> CompareCertThumbprints(IEnumerable<string> localCerts, IEnumerable<string> managedCertThumbprints)
{
    // Except yields each remaining thumbprint once, so the list contains no duplicates.
    var untrackedCerts = localCerts.Except(managedCertThumbprints).ToList();
    var untrackedJsonFilePath = Path.Combine(this.thumbprintsFilePath, "untracked.json");

    // One timestamp for everything discovered during this run.
    var firstSeen = DateTime.Now;

    Dictionary<string, DateTime> knownUntrackedCerts = null;
    if (File.Exists(untrackedJsonFilePath))
    {
        var untrackedJson = File.ReadAllText(untrackedJsonFilePath);

        // Yields null for an empty file or a literal "null"; handled like a missing file below.
        knownUntrackedCerts = JsonConvert.DeserializeObject<Dictionary<string, DateTime>>(untrackedJson);
    }

    if (knownUntrackedCerts == null)
    {
        // Create untracked file with all untracked certs. All timestamps should be DateTime.now
        return untrackedCerts.ToDictionary(c => c, c => firstSeen);
    }

    // Keep the first-seen time of known entries and only add the newly discovered thumbprints.
    foreach (var thumbprint in untrackedCerts)
    {
        if (!knownUntrackedCerts.ContainsKey(thumbprint))
        {
            knownUntrackedCerts[thumbprint] = firstSeen;
        }
    }

    return knownUntrackedCerts;
}
/// <summary>
/// Imports a certificate directory with standard ia/personal/root/trustedpeople folders into the local machine.
/// Only .pfx and .cer files are considered. A certificate whose thumbprint is already present in
/// the target store is not loaded again. Rights are granted to <paramref name="grantUsers"/> for
/// .pfx files only.
/// </summary>
/// <param name="server">Environment properties of the current server (not used by this method).</param>
/// <param name="importFolder">Directory that contains one sub folder per store type.</param>
/// <param name="grantUsers">Accounts that are granted rights to the imported .pfx certificates.</param>
/// <param name="password">Password used to open the certificate files.</param>
/// <exception cref="InvalidOperationException">
/// Thrown when a .pfx certificate cannot be found in its store after it was loaded.
/// </exception>
private void ImportCertificatesToLocalMachine(ServerInfo server, string importFolder, string[] grantUsers, string password)
{
    WriteVerbose($"Importing certificates from '{importFolder}' with rights granted to users '{string.Join(",", grantUsers)}'");

    foreach (var store in this.storeTypes)
    {
        // Determine the folder of certs to import for the store type.
        var folderPath = Path.Combine(importFolder, store);
        if (!Directory.Exists(folderPath))
        {
            continue;
        }

        // Grab all of the certs, filtered to .pfx and .cer regardless of casing.
        var certificates = Directory.GetFiles(folderPath)
            .Where(file => extensions.Contains(Path.GetExtension(file), StringComparer.OrdinalIgnoreCase))
            .ToList();

        // Move on if there are no certificates to install for this store type.
        if (certificates.Count == 0)
        {
            continue;
        }

        WriteVerbose($"Importing certificates from '{folderPath}' into the {store} store.");

        // Import certificates into the appropriate store.
        StoreName storeName = Extensions.GetStoreNameByFolderName(store);
        foreach (var certificatePath in certificates)
        {
            WriteVerbose($"Importing certificate {certificatePath}");

            // Open the file only long enough to read its thumbprint.
            string thumbprint;
            using (var fileCertificate = new X509Certificate2(certificatePath, password))
            {
                thumbprint = fileCertificate.Thumbprint;
            }

            // See if the cert is already on the local machine.
            X509Certificate2 certificate = CertificateHelper.FindCertificateByThumbprint(thumbprint, storeName, StoreLocation.LocalMachine, "localhost");

            // Only load the cert if it isn't on the local machine.
            if (certificate == null)
            {
                CertificateHelper.LoadCertificateToStore(certificatePath, storeName, StoreLocation.LocalMachine, password);
                certificate = CertificateHelper.FindCertificateByThumbprint(thumbprint, storeName, StoreLocation.LocalMachine, "localhost");
            }

            // Grant user rights, only if it's a pfx.
            if (string.Equals(Path.GetExtension(certificatePath), ".pfx", StringComparison.OrdinalIgnoreCase))
            {
                if (certificate == null)
                {
                    // Fail with a clear message instead of a NullReferenceException further down.
                    throw new InvalidOperationException($"Certificate '{certificatePath}' with thumbprint {thumbprint} was not found in the {store} store after import, so rights cannot be granted.");
                }

                GrantRightsToCertificate(certificate, storeName, grantUsers);
            }
        }
    }
}
/// <summary>
/// Grants user access rights to the private key of the specified certificate.
/// </summary>
/// <param name="certificate">Certificate whose private key should be made accessible.</param>
/// <param name="storeName">Store the certificate was found in (not used by this method).</param>
/// <param name="users">Accounts that receive rights to the private key file.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="certificate"/> is null.</exception>
/// <exception cref="InvalidOperationException">
/// Thrown when the certificate has no RSA private key that is exposed as a CNG key.
/// </exception>
private void GrantRightsToCertificate(X509Certificate2 certificate, StoreName storeName, string[] users)
{
    if (certificate == null)
    {
        throw new ArgumentNullException(nameof(certificate));
    }

    // Gather properties from the cert.
    string certName = certificate.GetNameInfo(X509NameType.SimpleName, false);
    string thumbprint = certificate.Thumbprint;

    // Look for the unique name of the cert, so we can track down the file to set ACL's
    string uniqueContainerName;
    using (var rsa = certificate.GetRSAPrivateKey())
    {
        // GetRSAPrivateKey returns null when there is no RSA private key, and the returned
        // implementation is not guaranteed to be RSACng.
        var rsaCng = rsa as RSACng;
        if (rsaCng == null)
        {
            throw new InvalidOperationException($"Certificate {certName}:{thumbprint} does not have an RSA private key that can be accessed as a CNG key.");
        }

        using (CngKey key = rsaCng.Key)
        {
            uniqueContainerName = key.UniqueName;
        }
    }

    // Locate the private key file.
    var keyFilePath = CertificateHelper.FindKeyLocation(uniqueContainerName);
    var pkFile = new FileInfo(Path.Combine(keyFilePath, uniqueContainerName));

    // Grant user permissions to the certificate.
    WriteVerbose($"Granting access to {certName}:{thumbprint} to users {string.Join(",", users)}");
    foreach (var user in users)
    {
        Common.Cryptography.CertificateHelper.GrantRightsToPrivateKeys(pkFile, user);
    }
}
/// <summary>
/// Returns the thumbprint of every certificate in the given local machine store.
/// </summary>
/// <param name="name">Store to read.</param>
/// <returns>One thumbprint per certificate, in the order the store returned them.</returns>
private List<string> GetAllThumbprintsFromStore(StoreName name)
{
    var storeCertificates = CertificateHelper.GetAllCertificates(name, StoreLocation.LocalMachine).ToList();

    var thumbprints = new List<string>();
    foreach (var storeCertificate in storeCertificates)
    {
        thumbprints.Add(storeCertificate.Thumbprint);
    }

    return thumbprints;
}
/// <summary>
/// Gets users to grant rights to on certificates. Returns the accounts running services whose
/// name contains "alkami", IIS_IUSRS, the microservice and database users, and the nag/radium
/// GMSA accounts when a prefix is supplied.
/// </summary>
/// <param name="microserviceUser">Microservice account; skipped when null or blank.</param>
/// <param name="databaseUser">Database account; skipped when null or blank.</param>
/// <param name="gmsaPrefix">Optional prefix used to derive the "fh\{prefix}.nag$" and "fh\{prefix}.radium$" accounts.</param>
/// <returns>Account names without blanks, without LocalSystem, and without duplicates (ignoring case).</returns>
private string[] GetUsersToGrantRightsTo(string microserviceUser, string databaseUser, string gmsaPrefix)
{
    var users = new List<string>();

    // Read users that are running alkami services.
    var services = ServiceController.GetServices();
    try
    {
        foreach (var service in services)
        {
            if (service.ServiceName.IndexOf("alkami", StringComparison.OrdinalIgnoreCase) >= 0)
            {
                users.Add(GetServiceAccountUser(service));
            }
        }
    }
    finally
    {
        // ServiceController is disposable; release every instance GetServices handed out.
        foreach (var service in services)
        {
            service.Dispose();
        }
    }

    // Append known users onto the end of the list of Alkami service users.
    users.Add("IIS_IUSRS");
    users.Add(microserviceUser);
    users.Add(databaseUser);

    // Determine nag/radium user by convention only if the gmsa prefix was passed in.
    if (!string.IsNullOrWhiteSpace(gmsaPrefix))
    {
        users.Add($"fh\\{gmsaPrefix}.nag$");
        users.Add($"fh\\{gmsaPrefix}.radium$");
    }

    return users
        // Drop accounts that could not be determined or were not configured.
        .Where(user => !string.IsNullOrWhiteSpace(user))
        // Remove any LocalSystem users.
        .Where(user => !string.Equals(user, "LocalSystem", StringComparison.OrdinalIgnoreCase))
        // Only return the unique usernames; compare ignoring case so "FH\x" and "fh\x" collapse.
        .Distinct(StringComparer.OrdinalIgnoreCase)
        .ToArray();
}
/// <summary>
/// Returns the account the given service runs as, read from the StartName property of the
/// matching Win32_Service WMI instance.
/// </summary>
/// <param name="service">Service to look up by its service name.</param>
/// <returns>
/// The start account of the first matching WMI instance, or null when no instance matches, the
/// instance has no start account, or the query fails (a warning is written in that case).
/// </returns>
private string GetServiceAccountUser(ServiceController service)
{
    try
    {
        // Backslashes and single quotes must be escaped inside a WQL string literal; an
        // unescaped quote in the service name would otherwise break the query.
        string escapedServiceName = service.ServiceName.Replace("\\", "\\\\").Replace("'", "\\'");

        System.Management.SelectQuery sQuery = new System.Management.SelectQuery($"select startname from Win32_Service where name = '{escapedServiceName}'");
        using (System.Management.ManagementObjectSearcher mgmtSearcher = new System.Management.ManagementObjectSearcher(sQuery))
        using (System.Management.ManagementObjectCollection queryResults = mgmtSearcher.Get())
        {
            foreach (System.Management.ManagementObject manageObject in queryResults)
            {
                using (manageObject)
                {
                    // Only the first match is relevant. The property value can be null.
                    return manageObject["Startname"]?.ToString();
                }
            }
        }
    }
    catch (Exception e)
    {
        WriteWarning($"Failed to determine user account for service {service.ServiceName} with error:\n{e.ToString()}");
        return null;
    }

    return null;
}
}
}