Source code for nf2.evaluation.output

import numpy as np
import torch
from astropy import units as u, constants
from astropy.coordinates import SkyCoord
from dateutil.parser import parse
from sunpy.coordinates import frames
from sunpy.map import Map
from torch import nn
from tqdm import tqdm

from nf2.data.util import spherical_to_cartesian, cartesian_to_spherical, vector_cartesian_to_spherical
from nf2.evaluation.energy import get_free_mag_energy
from nf2.evaluation.metric import energy
from nf2.evaluation.output_metrics import metric_mapping, normalize_metric_names
from nf2.train.model import VectorPotentialModel
from nf2.train.transform import HeightRangeTransformModel, AzimuthTransformModel, HeightTransformModel


class BaseOutput:
    """Common checkpoint evaluator shared by the geometry-specific outputs.

    The checkpoint is read with :func:`torch.load`; its ``'model'`` entry is
    the network that gets evaluated and its ``'data'`` entry supplies the
    normalization constants used to convert results back to physical units.

    Users normally construct geometry-specific helpers through :func:`nf2.load`
    instead of instantiating this class directly.
    """

    def __init__(self, checkpoint, device=None):
        # Fall back to the GPU when one is visible, otherwise run on the CPU.
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.state = torch.load(checkpoint, map_location=device, weights_only=False)
        network = self.state['model']
        # Decided on the raw model, before any DataParallel wrapping hides its type.
        self._requires_grad = isinstance(network, VectorPotentialModel)
        if torch.cuda.device_count() > 1:
            network = nn.DataParallel(network)
        self.model = network
        self.device = device
        self.c = constants.c

    @property
    def Gauss_per_dB(self):
        """Checkpoint value ``data['Gauss_per_dB']`` as a quantity in Gauss."""
        return self.state['data']['Gauss_per_dB'] * u.G

    @property
    def m_per_ds(self):
        """Checkpoint value ``data['Mm_per_ds']`` converted from Mm to meters."""
        length = self.state['data']['Mm_per_ds'] * u.Mm
        return length.to(u.m)

    def load_coords(self, coords, batch_size=int(2 ** 12), progress=False, compute_jacobian=True, metrics=None):
        """Evaluate the neural field at normalized model coordinates.

        Parameters
        ----------
        coords:
            Array with final dimension ``(x, y, z)`` in model coordinates.
        batch_size:
            Number of coordinates evaluated per model call. It is multiplied
            by the CUDA device count when CUDA is available.
        progress:
            Show a progress bar.
        compute_jacobian:
            Include the magnetic-field Jacobian in the output.
        metrics:
            Optional metric names from ``nf2.evaluation.output_metrics``.

        Returns
        -------
        dict
            Model outputs reshaped to the leading dimensions of ``coords`` and
            scaled with the checkpoint constants, plus a ``'metrics'`` entry
            holding the results of the requested metrics.

        Raises
        ------
        ValueError
            If a requested metric is not registered in ``metric_mapping``.
        """
        if torch.cuda.is_available():
            batch_size = batch_size * torch.cuda.device_count()
        metrics = normalize_metric_names(metrics)

        def _evaluate(points):
            # Flatten to a list of 3-vectors; the original shape is restored below.
            tensor = torch.tensor(points, dtype=torch.float32)
            full_shape = tensor.shape
            flat = tensor.reshape((-1, 3))

            n_batches = int(np.ceil(flat.shape[0] / batch_size))
            batch_ids = range(n_batches)
            if progress:
                batch_ids = tqdm(batch_ids, desc='Load NF2')

            collected = {}
            for idx in batch_ids:
                self.model.zero_grad()
                batch = flat[idx * batch_size: (idx + 1) * batch_size]
                batch = batch.to(self.device)
                batch.requires_grad = True
                prediction = self.model(batch, compute_jacobian=compute_jacobian)
                for name, value in prediction.items():
                    if name not in collected:
                        collected[name] = []
                    collected[name] += [value.detach().cpu()]

            merged = {}
            for name, chunks in collected.items():
                joined = torch.cat(chunks)
                merged[name] = joined.reshape(*full_shape[:-1], *joined.shape[1:]).numpy()

            # Convert the normalized outputs to physical units.
            merged['b'] = merged['b'] * self.Gauss_per_dB
            if 'a' in merged:
                merged['a'] = merged['a'] * self.Gauss_per_dB * self.m_per_ds
            if 'p' in merged:
                merged['p'] = merged['p'] * self.Gauss_per_dB ** 2
            return merged

        if self._requires_grad or compute_jacobian:
            # Autograd has to stay enabled for these evaluations.
            model_out = _evaluate(coords)

            if compute_jacobian:
                model_out['jac_matrix'] = model_out['jac_matrix'] * self.Gauss_per_dB / self.m_per_ds
        else:
            with torch.no_grad():
                model_out = _evaluate(coords)

        # Metrics see the model outputs, the coordinates and earlier metric results.
        state = {**model_out, 'coords': coords}
        metrics_out = {}
        for key in metrics:
            if key not in metric_mapping:
                valid_options = ', '.join(sorted(metric_mapping))
                raise ValueError(f"Unknown output metric '{key}'. Valid options: {valid_options}")
            metric_out = metric_mapping[key](**state)
            metrics_out.update(metric_out)
            state.update(metric_out)

        model_out['metrics'] = metrics_out
        return model_out


class CartesianOutput(BaseOutput):
    """Evaluate Cartesian NF2 extrapolation checkpoints."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.state['data']['type'] == 'cartesian', 'Requires cartesian NF2 data!'
        data = self.state['data']

        # Checkpoints may store one entry per input map; use the first one here.
        coord_range = data['coord_range']
        self.coord_range = coord_range[0] if isinstance(coord_range, list) else coord_range
        self.max_height = data['max_height']
        ds_per_pixel = data['ds_per_pixel']
        self.ds_per_pixel = ds_per_pixel[0] if isinstance(ds_per_pixel, list) else ds_per_pixel
        self.Mm_per_ds = data['Mm_per_ds']
        self.Mm_per_pixel = self.ds_per_pixel * self.Mm_per_ds

        self.wcs = [wcs for wcs in data['wcs'] if wcs is not None] if 'wcs' in data else None
        self.time = None if self.wcs is None or len(self.wcs) == 0 else parse(self.wcs[0].wcs.dateobs)
        self.data_config = data

    def load_cube(self, height_range=None, x_range=None, y_range=None, Mm_per_pixel=None, **kwargs):
        """Load a regularly sampled Cartesian volume.

        Ranges are specified in megameters. Additional keyword arguments are
        forwarded to :meth:`BaseOutput.load_coords`.

        Returns the model output extended by ``'coords'`` (grid positions in
        Mm) and ``'Mm_per_pixel'`` (the sampling that was used).
        """
        x_min, x_max = self.coord_range[0] if x_range is None else np.array(x_range) / self.Mm_per_ds
        y_min, y_max = self.coord_range[1] if y_range is None else np.array(y_range) / self.Mm_per_ds
        z_min, z_max = (0, self.max_height / self.Mm_per_ds) if height_range is None \
            else (h / self.Mm_per_ds for h in height_range)

        Mm_per_pixel = self.Mm_per_pixel if Mm_per_pixel is None else Mm_per_pixel
        ds_per_pixel = Mm_per_pixel / self.Mm_per_ds

        n_x_pix = np.round((x_max - x_min) / ds_per_pixel).astype(int)
        n_y_pix = np.round((y_max - y_min) / ds_per_pixel).astype(int)
        n_z_pix = np.round((z_max - z_min) / ds_per_pixel).astype(int)

        # Integer pixel grid -> model coordinates.
        coords = np.stack(np.mgrid[:n_x_pix, :n_y_pix, :n_z_pix], -1)
        coords = coords * ds_per_pixel + np.array([x_min, y_min, z_min]).reshape((1, 1, 1, 3))

        model_out = self.load_coords(coords, **kwargs)
        coords_Mm = coords / ds_per_pixel * Mm_per_pixel
        return {**model_out, 'coords': coords_Mm, 'Mm_per_pixel': Mm_per_pixel}

    def load_slice(self, z=0 * u.Mm, Mm_per_pixel=None, **kwargs):
        """Load one horizontal Cartesian slice at height ``z``.

        Parameters
        ----------
        z:
            Height of the slice as a length quantity.
        Mm_per_pixel:
            Sampling of the slice in Mm per pixel; defaults to the native
            sampling of the checkpoint.
        **kwargs:
            Forwarded to :meth:`BaseOutput.load_coords`.

        Returns the model output extended by ``'coords'`` (sampled positions
        in model coordinates) and ``'Mm_per_pixel'``.
        """
        x_min, x_max = self.coord_range[0]
        y_min, y_max = self.coord_range[1]
        Mm_per_pixel = self.Mm_per_pixel if Mm_per_pixel is None else Mm_per_pixel
        # Honor the requested sampling so the reported 'Mm_per_pixel' matches the grid.
        ds_per_pixel = Mm_per_pixel / self.Mm_per_ds
        n_x = int(np.round((x_max - x_min) / ds_per_pixel + 1))
        n_y = int(np.round((y_max - y_min) / ds_per_pixel + 1))
        # x/y are in model units, so the height is converted with Mm_per_ds as well
        # (dividing by Mm_per_pixel would give pixel units instead).
        z_ds = np.ones((1,), dtype=np.float32) * z.to_value(u.Mm) / self.Mm_per_ds

        coords = np.stack(
            np.meshgrid(np.linspace(x_min, x_max, n_x),
                        np.linspace(y_min, y_max, n_y),
                        z_ds, indexing='ij'), -1)
        coords = coords[:, :, 0]  # drop the singleton height axis

        model_out = self.load_coords(coords, **kwargs)
        return {**model_out, 'coords': coords, 'Mm_per_pixel': Mm_per_pixel}

    def load_maps(self, **kwargs):
        """Load SunPy maps for integrated field strength, current, and energy.

        Keyword arguments are forwarded to :meth:`load_cube`. All quantities
        are summed along the height axis.
        """
        model_out = self.load_cube(**kwargs)
        # NOTE(review): the current density is not a direct model output; it is
        # looked up at the top level first and then among the computed metrics.
        # This presumably requires the caller to request the matching metric
        # (key 'j') -- confirm against nf2.evaluation.output_metrics.
        j = model_out['j'] if 'j' in model_out else model_out.get('metrics', {})['j']

        j_map = np.linalg.norm(j, axis=-1).sum(axis=-1)
        b_map = np.linalg.norm(model_out['b'], axis=-1).sum(axis=-1)
        energy_map = energy(model_out['b']).sum(axis=-1)
        free_energy_map = get_free_mag_energy(model_out['b']).sum(axis=-1)
        # NOTE(review): self.wcs is a list of WCS objects here; verify that Map
        # accepts it in this form.
        return {'b': Map(b_map, wcs=self.wcs),
                'j': Map(j_map, wcs=self.wcs),
                'energy': Map(energy_map, wcs=self.wcs),
                'free_energy': Map(free_energy_map, wcs=self.wcs)}

    def trace_bottom(self, Mm_per_pixel=None, **kwargs):
        """Trace field lines in both directions from a grid at height zero.

        Keyword arguments are forwarded to :meth:`trace`. Returns a dict that
        maps the flat index of each start point to its list of positions,
        ordered from the end of the backward trace to the end of the forward
        trace.
        """
        x_min, x_max = self.coord_range[0]
        y_min, y_max = self.coord_range[1]
        Mm_per_pixel = self.Mm_per_pixel if Mm_per_pixel is None else Mm_per_pixel
        pixel_per_ds = self.Mm_per_ds / Mm_per_pixel

        coords = np.stack(
            np.meshgrid(np.linspace(x_min, x_max, int((x_max - x_min) * pixel_per_ds + 1)),
                        np.linspace(y_min, y_max, int((y_max - y_min) * pixel_per_ds + 1)),
                        np.zeros((1,), dtype=np.float32), indexing='ij'), -1)

        forward_trace = self.trace(coords, **kwargs)
        backward_trace = self.trace(coords, direction=-1, **kwargs)
        # Both traces start with the same point; drop it from the backward part.
        traces = {i: list(reversed(backward_trace[i][1:])) + forward_trace[i] for i in forward_trace}
        return traces

    def trace(self, start_coords, direction=1, max_iterations=None, **kwargs):
        """Trace field lines with a 4th order Runge-Kutta scheme.

        Parameters
        ----------
        start_coords:
            Start positions in model coordinates with final dimension
            ``(x, y, z)``. Leading dimensions are flattened; the input array
            is not modified.
        direction:
            Positive values follow the field, negative values trace against it.
        max_iterations:
            Upper bound for the number of steps. The default scales with the
            extent of the volume divided by the step size.
        **kwargs:
            Forwarded to :meth:`BaseOutput.load_coords`. ``compute_jacobian``
            defaults to ``False`` here because only ``'b'`` is used.

        Returns
        -------
        dict
            Maps the flat index of each start point to the list of visited
            positions (arrays of shape ``(3,)``), beginning with the start
            point. A line ends once it leaves the volume.
        """
        step = self.ds_per_pixel / 4  # quarter of a pixel
        x_min, x_max = self.coord_range[0]
        y_min, y_max = self.coord_range[1]
        z_min, z_max = 0, self.max_height / self.Mm_per_ds
        if max_iterations is None:
            max_iterations = ((x_max - x_min) + (y_max - y_min) + (z_max - z_min)) / step * 3

        sign = np.sign(direction)
        kwargs.setdefault('compute_jacobian', False)

        # Work on a private (N, 3) copy so the caller's array stays untouched.
        coords = np.array(start_coords, dtype=np.float64).reshape((-1, 3))
        field_lines = {i: [c.copy()] for i, c in enumerate(coords)}

        def _unit_direction(points):
            # Normalized tracing direction at the given points.
            b = self.load_coords(points, **kwargs)['b']
            b = np.asarray(getattr(b, 'value', b)) * sign  # strip physical units
            with np.errstate(invalid='ignore', divide='ignore'):
                # A vanishing field yields NaN, which terminates that line below.
                return b / np.linalg.norm(b, axis=-1, keepdims=True)

        iteration = 0
        while iteration < max_iterations:
            # A NaN position marks a finished line; only advance the others.
            active = ~np.isnan(coords).any(axis=-1)
            if not active.any():
                break

            p0 = coords[active]
            k1 = _unit_direction(p0)
            k2 = _unit_direction(p0 + k1 * step / 2)
            k3 = _unit_direction(p0 + k2 * step / 2)
            k4 = _unit_direction(p0 + k3 * step)
            next_coords = p0 + (k1 + 2 * k2 + 2 * k3 + k4) * step / 6

            inside = (next_coords[..., 0] >= x_min) & (next_coords[..., 0] <= x_max) & \
                     (next_coords[..., 1] >= y_min) & (next_coords[..., 1] <= y_max) & \
                     (next_coords[..., 2] >= z_min) & (next_coords[..., 2] <= z_max)
            next_coords[~inside] = np.nan  # todo write final value at the boundary
            coords[active] = next_coords

            for i in np.flatnonzero(active):
                point = coords[i]
                if not np.isnan(point).any():
                    field_lines[int(i)].append(point.copy())
            iteration += 1

        return field_lines
class HeightTransformOutput(CartesianOutput):
    """Evaluate the height transform module stored in a Cartesian checkpoint."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = self.state['transforms']
        height_transforms = [t for t in self.transforms
                             if isinstance(t, (HeightRangeTransformModel, HeightTransformModel))]
        assert len(height_transforms) == 1, 'Requires transform module!'
        self.transform_module = height_transforms[0]

        # Unlike the scalar attributes of the parent class, keep the per-map entries.
        self.coord_range_list = self.state['data']['coord_range']
        self.height_mapping_list = self.state['data']['height_mapping']
        self.ds_per_pixel_list = self.state['data']['ds_per_pixel']

    def load_height_mapping(self, Mm_per_pixel=None, **kwargs):
        """Evaluate the height transform for every map that has a height mapping.

        Parameters
        ----------
        Mm_per_pixel:
            Sampling in Mm per pixel; defaults to the sampling stored for the
            respective map.
        **kwargs:
            Forwarded to :meth:`load_transformed_coords`.

        Returns
        -------
        list of dict
            One entry per map with a height mapping, with the keys
            ``'height'``, ``'coords'``, ``'original_coords'`` and
            ``'Mm_per_pixel'``.
        """
        mapping_out = []
        for coord_range, height_mapping, ds_per_pixel in zip(self.coord_range_list, self.height_mapping_list,
                                                             self.ds_per_pixel_list):
            if height_mapping is None:
                continue

            x_min, x_max = coord_range[0]
            y_min, y_max = coord_range[1]
            z = height_mapping['z'] / self.Mm_per_ds
            pixel_per_ds = self.Mm_per_ds / Mm_per_pixel if Mm_per_pixel is not None else 1 / ds_per_pixel

            coords = np.stack(
                np.meshgrid(np.linspace(x_min, x_max, int((x_max - x_min) * pixel_per_ds)),
                            np.linspace(y_min, y_max, int((y_max - y_min) * pixel_per_ds)),
                            z, indexing='ij'), -1)

            in_tensors = {'coords': coords}
            if 'z_min' in height_mapping and 'z_max' in height_mapping:
                # Pass the allowed height interval along with every coordinate.
                height_range = np.zeros((*coords.shape[:-1], 2), dtype=np.float32)
                height_range[..., 0] = height_mapping['z_min'] / self.Mm_per_ds
                height_range[..., 1] = height_mapping['z_max'] / self.Mm_per_ds
                in_tensors['height_range'] = height_range
            in_tensors = {k: torch.tensor(v, dtype=torch.float32) for k, v in in_tensors.items()}

            model_out = self.load_transformed_coords(in_tensors, **kwargs)
            mapping_out.append({'height': z * self.Mm_per_ds,
                                'coords': model_out['coords'] * self.Mm_per_ds * u.Mm,
                                'original_coords': coords * self.Mm_per_ds * u.Mm,
                                'Mm_per_pixel': self.Mm_per_ds / pixel_per_ds})
        return mapping_out

    @torch.no_grad()
    def load_transformed_coords(self, in_tensors, batch_size=int(2 ** 12), progress=False):
        """Run the transform module batch-wise on a dict of tensors.

        All tensors in ``in_tensors`` are flattened to two dimensions, passed
        to the transform module in batches of ``batch_size`` rows, and every
        output is reshaped to the leading dimensions of the first input
        tensor. Returns a dict of numpy arrays.
        """
        first = next(iter(in_tensors.values()))
        cube_shape = first.shape[:-1]
        flattened_tensors = {k: v.reshape((-1, v.shape[-1])) for k, v in in_tensors.items()}
        n_rows = next(iter(flattened_tensors.values())).shape[0]

        cube = {}
        it = range(int(np.ceil(n_rows / batch_size)))
        it = tqdm(it) if progress else it
        for i in it:
            self.transform_module.zero_grad()
            batch = {k: v[i * batch_size: (i + 1) * batch_size].to(self.device)
                     for k, v in flattened_tensors.items()}
            transformed_coords = self.transform_module(batch)
            for name, value in transformed_coords.items():
                cube.setdefault(name, []).append(value.detach().cpu())

        cube = {k: torch.cat(v).reshape(*cube_shape, -1).numpy() for k, v in cube.items()}
        return cube


class DisambiguationOutput(CartesianOutput):
    """Evaluate the azimuth transform module stored in a Cartesian checkpoint."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = self.state['transforms']
        disambiguation_transforms = [t for t in self.transforms if isinstance(t, AzimuthTransformModel)]
        assert len(disambiguation_transforms) == 1, 'Requires transform module!'
        self.transform_module = disambiguation_transforms[0]

        self.coord_range_list = self.state['data']['coord_range']
        self.height_mapping_list = self.state['data']['height_mapping']
        self.ds_per_pixel_list = self.state['data']['ds_per_pixel']

    def load_slice(self, z=0 * u.Mm, coord_range=None, **kwargs):
        """Evaluate the transform module on a slice at height ``z`` for every map.

        Parameters
        ----------
        z:
            Height of the slice as a length quantity.
        coord_range:
            Accepted for interface compatibility; currently unused. The
            ranges stored in the checkpoint are always used.
        **kwargs:
            Forwarded to :meth:`load_transformed_coords`.

        Returns
        -------
        list of dict
            One entry per stored map with the keys ``'coords'`` and ``'flip'``.
        """
        # Convert once, outside the loop: rebinding ``z`` per iteration would
        # break the unit conversion from the second map onwards.
        z_ds = z.to_value(u.Mm) / self.Mm_per_ds

        disambiguation = []
        for map_range, ds_per_pixel in zip(self.coord_range_list, self.ds_per_pixel_list):
            x_min, x_max = map_range[0]
            y_min, y_max = map_range[1]
            pixel_per_ds = 1 / ds_per_pixel

            coords = np.stack(
                np.meshgrid(np.linspace(x_min, x_max, int((x_max - x_min) * pixel_per_ds)),
                            np.linspace(y_min, y_max, int((y_max - y_min) * pixel_per_ds)),
                            z_ds, indexing='ij'), -1)

            model_out = self.load_transformed_coords(coords, **kwargs)
            disambiguation.append({'coords': coords * self.Mm_per_ds * u.Mm,
                                   'flip': model_out['flip']})
        return disambiguation

    def load_transformed_coords(self, coords, batch_size=int(2 ** 12), progress=False):
        """Run the transform module batch-wise on ``coords``.

        ``coords`` has final dimension ``(x, y, z)`` in model coordinates.
        Every output of the module is reshaped to the leading dimensions of
        ``coords`` and returned as a numpy array. Gradients are disabled.
        """
        with torch.no_grad():
            # normalize and to tensor
            points = torch.tensor(coords, dtype=torch.float32)
            coords_shape = points.shape
            points = points.reshape((-1, 3))

            cube = {}
            it = range(int(np.ceil(points.shape[0] / batch_size)))
            it = tqdm(it) if progress else it
            for i in it:
                self.transform_module.zero_grad()
                batch = points[i * batch_size: (i + 1) * batch_size].to(self.device)
                transformed_coords = self.transform_module({'coords': batch})
                for name, value in transformed_coords.items():
                    cube.setdefault(name, []).append(value.detach().cpu())

            return {k: torch.cat(v).reshape(*coords_shape[:-1]).numpy() for k, v in cube.items()}
class SphericalOutput(BaseOutput):
    """Evaluate spherical NF2 extrapolation checkpoints."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.state['data']['type'] == 'spherical', 'Requires spherical NF2 data!'
        self.radius_range = self.state['data']['radius_range']
        # Values stored without a unit are taken to be solar radii.
        if not hasattr(self.radius_range, 'unit'):
            self.radius_range = self.radius_range * u.solRad

    def load_spherical(self, radius_range: u.Quantity = None,
                       latitude_range: u.Quantity = (-np.pi / 2, np.pi / 2) * u.rad,
                       longitude_range: u.Quantity = (0, 2 * np.pi) * u.rad,
                       sampling=(100, 180, 360), **kwargs):
        """Load a regularly sampled spherical volume.

        ``sampling`` is ordered as radius, latitude, longitude. Ranges should
        use Astropy units. Additional keyword arguments are forwarded to
        :meth:`BaseOutput.load_coords`.

        Returns the model output extended by ``'coords'`` (Cartesian
        positions) and ``'spherical_coords'`` (radius, colatitude, longitude).
        """
        radius_range = radius_range if radius_range is not None else self.radius_range
        # The grid is built in colatitude; sorting keeps the bounds ascending.
        colatitude_range = sorted(np.pi / 2 * u.rad - latitude_range)

        spherical_coords = np.stack(
            np.meshgrid(
                np.linspace(radius_range[0].to_value(u.solRad), radius_range[1].to_value(u.solRad), sampling[0]),
                np.linspace(colatitude_range[0].to_value(u.rad), colatitude_range[1].to_value(u.rad), sampling[1]),
                np.linspace(longitude_range[0].to_value(u.rad), longitude_range[1].to_value(u.rad), sampling[2]),
                indexing='ij'), -1)
        cartesian_coords = spherical_to_cartesian(spherical_coords)
        # Solar radii -> model coordinates.
        scaled_coords = cartesian_coords * (1 * u.solRad / self.m_per_ds).to_value(u.dimensionless_unscaled)

        model_out = self.load_coords(scaled_coords, **kwargs)
        return {**model_out, 'coords': cartesian_coords, 'spherical_coords': spherical_coords}

    def load(self, radius_range: u.Quantity = None,
             latitude_range: u.Quantity = (-np.pi / 2, np.pi / 2) * u.rad,
             longitude_range: u.Quantity = (0, 2 * np.pi) * u.rad,
             resolution: u.Quantity = 64 * u.pix / u.solRad, nan_value=0, **kwargs):
        """Load a Cartesian cube covering a spherical shell selection.

        Parameters
        ----------
        radius_range, latitude_range, longitude_range:
            Selected shell section as Astropy quantities. ``radius_range``
            defaults to the range stored in the checkpoint.
        resolution:
            Sampling of the Cartesian cube in pixels per solar radius.
        nan_value:
            Fill value for cube positions outside the selection.
        **kwargs:
            Forwarded to :meth:`BaseOutput.load_coords`.

        Returns
        -------
        dict
            Model outputs and metrics embedded in the full cube, together
            with ``'coords'``, ``'spherical_coords'`` and ``'b_rtp'``.
        """
        radius_range = radius_range if radius_range is not None else self.radius_range
        # convert latitude to colatitude
        colatitude_range = sorted(np.pi / 2 * u.rad - latitude_range)

        min_radius, max_radius = radius_range[0].to_value(u.solRad), radius_range[1].to_value(u.solRad)
        min_colatitude, max_colatitude = colatitude_range[0].to_value(u.rad), colatitude_range[1].to_value(u.rad)
        min_lon, max_lon = longitude_range[0].to_value(u.rad), longitude_range[1].to_value(u.rad)

        # Coarse sampling of the selection to find its Cartesian bounding box.
        spherical_bounds = np.stack(
            np.meshgrid(np.linspace(min_radius, max_radius, 50),
                        np.linspace(min_colatitude, max_colatitude, 50),
                        np.linspace(min_lon, max_lon, 50), indexing='ij'), -1)
        cartesian_bounds = spherical_to_cartesian(spherical_bounds)
        x_min, x_max = cartesian_bounds[..., 0].min(), cartesian_bounds[..., 0].max()
        y_min, y_max = cartesian_bounds[..., 1].min(), cartesian_bounds[..., 1].max()
        z_min, z_max = cartesian_bounds[..., 2].min(), cartesian_bounds[..., 2].max()

        res = resolution.to_value(u.pix / u.solRad)
        coords = np.stack(
            np.meshgrid(np.linspace(x_min, x_max, int((x_max - x_min) * res)),
                        np.linspace(y_min, y_max, int((y_max - y_min) * res)),
                        np.linspace(z_min, z_max, int((z_max - z_min) * res)), indexing='ij'), -1)

        spherical_coords = cartesian_to_spherical(coords)
        rad_coord = spherical_coords[..., 0]
        colatitude_coord = spherical_coords[..., 1]
        lon_coord = spherical_coords[..., 2] % (2 * np.pi)

        # only evaluate coordinates in simulation volume
        if min_colatitude == max_colatitude:
            lat_cond = np.ones_like(colatitude_coord, dtype=bool)
        else:
            lat_cond = (colatitude_coord >= min_colatitude) & (colatitude_coord < max_colatitude)
        if min_lon == max_lon:
            lon_cond = np.ones_like(lon_coord, dtype=bool)
        else:
            lon_cond = (lon_coord >= min_lon) & (lon_coord < max_lon)
            if max_lon > 2 * np.pi:
                # Selection wraps around 2*pi: also accept the wrapped part.
                lon_cond = lon_cond | ((lon_coord < max_lon - 2 * np.pi) & (lon_coord >= 0))
        rad_cond = (rad_coord >= min_radius) & (rad_coord < max_radius)
        condition = rad_cond & lat_cond & lon_cond

        scaled_coords = coords * (1 * u.solRad / self.m_per_ds).to_value(u.dimensionless_unscaled)
        sub_coords = scaled_coords[condition]
        cube_shape = scaled_coords.shape[:-1]

        model_out = self.load_coords(sub_coords, **kwargs)

        def _embed(sub_v):
            # Place the values of the selected points into a cube filled with nan_value.
            volume = np.ones(cube_shape + sub_v.shape[1:]) * nan_value
            if hasattr(sub_v, 'unit'):  # preserve units
                volume = volume * sub_v.unit
            volume[condition] = sub_v
            return volume

        spherical_out = {'spherical_coords': spherical_coords, 'coords': coords, 'metrics': {}}
        metrics = model_out.pop('metrics')
        for k, sub_v in metrics.items():
            spherical_out['metrics'][k] = _embed(sub_v)
        for k, sub_v in model_out.items():
            spherical_out[k] = _embed(sub_v)

        spherical_out['b_rtp'] = vector_cartesian_to_spherical(spherical_out['b'], spherical_coords)
        return spherical_out

    def load_spherical_coords(self, spherical_coords: SkyCoord, **kwargs):
        """Evaluate the model at explicit SkyCoord positions.

        Keyword arguments are forwarded to :meth:`BaseOutput.load_coords`.
        The result is extended by ``'spherical_coords'`` and ``'coords'``.
        """
        cartesian_coords, spherical_coords = self._skycoords_to_cartesian(spherical_coords)
        scaled_coords = cartesian_coords * (1 * u.solRad / self.m_per_ds).to_value(u.dimensionless_unscaled)
        model_out = self.load_coords(scaled_coords, **kwargs)
        model_out['spherical_coords'] = spherical_coords
        model_out['coords'] = cartesian_coords
        return model_out

    def _skycoords_to_cartesian(self, spherical_coords):
        """Convert SkyCoord positions to Cartesian and (r, colatitude, lon) arrays.

        The coordinates are transformed to the Heliographic Carrington frame
        first. Returns ``(cartesian_coords, spherical_coords)``.
        """
        spherical_coords = spherical_coords.transform_to(frames.HeliographicCarrington)
        r = spherical_coords.radius
        # A dimensionless radius is taken to be in solar radii.
        r = r * u.solRad if r.unit == u.dimensionless_unscaled else r
        spherical_coords = np.stack([
            r.to(u.solRad).value,
            np.pi / 2 - spherical_coords.lat.to(u.rad).value,
            spherical_coords.lon.to(u.rad).value,
        ], -1)
        cartesian_coords = spherical_to_cartesian(spherical_coords)
        return cartesian_coords, spherical_coords