Source code for nf2.export

"""Unified export command for NF2 result files."""

from __future__ import annotations

import argparse
import glob
import os
from pathlib import Path


def _geometry(nf2_path: str) -> str | None:
    import torch

    state = torch.load(nf2_path, map_location="cpu", weights_only=False)
    return state.get("data", {}).get("type")


def _normalize_format(fmt: str) -> str:
    fmt = fmt.lower().replace("-", "_")
    if fmt in {"height", "height_npz", "heights", "heights_npz"}:
        return "height"
    return fmt


[docs] def export_file( nf2_path: str, out_path: str | None = None, *, fmt: str = "vtk", Mm_per_pixel: float | None = None, height_range: list[float] | None = None, metrics: list[str] | None = None, x_range: list[float] | None = None, y_range: list[float] | None = None, radius_range: list[float] | None = None, latitude_range: list[float] | None = None, longitude_range: list[float] | None = None, pixels_per_solRad: int = 64, progress: bool = True, ): """Export one NF2 file to a supported exchange format. Parameters ---------- nf2_path: Path to an ``extrapolation_result.nf2`` checkpoint. out_path: Optional output path. Converter defaults are used when omitted. fmt: Export format: ``vtk``, ``npz``, ``hdf5``/``h5``, ``fits``, or ``height`` for multi-height surface mappings. Mm_per_pixel, height_range, x_range, y_range: Cartesian sampling controls in megameters. radius_range, latitude_range, longitude_range, pixels_per_solRad: Spherical VTK sampling controls. Radius is in solar radii; angles are in degrees. metrics: Derived quantities to include, such as ``j``, ``alpha``, or ``free_energy_fft``. progress: Show converter progress where supported. """ fmt = _normalize_format(fmt) metrics = metrics if metrics is not None else ["j"] geometry = _geometry(nf2_path) if geometry == "spherical" and fmt != "vtk": raise ValueError("Spherical checkpoints currently support `vtk` export only.") if fmt == "vtk": if geometry == "spherical": from nf2.convert.nf2_to_vtk_spherical import convert return convert( nf2_path=nf2_path, out_path=out_path, metrics=metrics, radius_range=radius_range, latitude_range=latitude_range, longitude_range=longitude_range, pixels_per_solRad=pixels_per_solRad, progress=progress, ) from nf2.convert.nf2_to_vtk import convert return convert( nf2_path=nf2_path, out_path=out_path, Mm_per_pixel=Mm_per_pixel, height_range=height_range, metrics=metrics, x_range=x_range, y_range=y_range, progress=progress, ) if fmt in {"npz", "npy"}: from nf2.convert.nf2_to_npz import convert return convert( nf2_path=nf2_path, out_path=out_path, Mm_per_pixel=Mm_per_pixel, height_range=height_range, metrics=metrics, x_range=x_range, y_range=y_range, progress=progress, ) if fmt == "height": from nf2.convert.nf2_height_to_npz import convert return convert( nf2_path=nf2_path, out_path=out_path, Mm_per_pixel=Mm_per_pixel, progress=progress, ) if fmt in {"hdf5", "h5"}: from nf2.convert.nf2_to_hdf5 import convert return convert( nf2_path=nf2_path, out_path=out_path, Mm_per_pixel=Mm_per_pixel, height_range=height_range, metrics=metrics, progress=progress, ) if fmt == "fits": from nf2.convert.nf2_to_fits import convert return convert( nf2_path=nf2_path, out_path=out_path, Mm_per_pixel=Mm_per_pixel, height_range=height_range, metrics=metrics, progress=progress, ) raise ValueError(f"Unsupported export format: {fmt}")
[docs] def export_series( patterns: list[str], out_dir: str, *, fmt: str = "vtk", overwrite: bool = False, **kwargs, ): """Export all NF2 files matched by one or more glob patterns. Existing files are skipped unless ``overwrite`` is true. """ fmt = _normalize_format(fmt) suffix = { "vtk": ".vtk", "npz": ".npz", "npy": ".npz", "height": ".height.npz", "hdf5": ".hdf5", "h5": ".hdf5", "fits": ".fits", }[fmt] nf2_paths = [path for pattern in patterns for path in sorted(glob.glob(pattern))] os.makedirs(out_dir, exist_ok=True) outputs = [] for nf2_path in nf2_paths: out_path = os.path.join(out_dir, Path(nf2_path).with_suffix(suffix).name) if os.path.exists(out_path) and not overwrite: continue outputs.append(export_file(nf2_path, out_path, fmt=fmt, progress=False, **kwargs)) return outputs
def main(): from nf2.evaluation.output_metrics import metric_mapping parser = argparse.ArgumentParser(description="Export NF2 extrapolation results.") parser.add_argument("nf2_path", nargs="+", help="NF2 file path or glob pattern") parser.add_argument( "--format", choices=["vtk", "npz", "hdf5", "h5", "fits", "height", "height-npz", "height_npz"], default="vtk", ) parser.add_argument("--out", help="Output file for a single input") parser.add_argument("--out-dir", help="Output directory for multiple inputs") parser.add_argument("--metrics", nargs="*", default=["j"], choices=sorted(metric_mapping)) parser.add_argument("--overwrite", action="store_true") cartesian = parser.add_argument_group("Cartesian sampling") cartesian.add_argument("--Mm_per_pixel", type=float, default=None) cartesian.add_argument("--height_range", type=float, nargs=2, default=None) cartesian.add_argument("--x_range", type=float, nargs=2, default=None) cartesian.add_argument("--y_range", type=float, nargs=2, default=None) spherical = parser.add_argument_group("Spherical VTK sampling") spherical.add_argument("--radius_range", type=float, nargs=2, default=None) spherical.add_argument("--latitude_range", type=float, nargs=2, default=None) spherical.add_argument("--longitude_range", type=float, nargs=2, default=None) spherical.add_argument("--pixels_per_solRad", type=int, default=64) args = parser.parse_args() matched = [path for pattern in args.nf2_path for path in sorted(glob.glob(pattern))] if len(matched) == 1 and args.out_dir is None: export_file( matched[0], args.out, fmt=args.format, Mm_per_pixel=args.Mm_per_pixel, height_range=args.height_range, metrics=args.metrics, x_range=args.x_range, y_range=args.y_range, radius_range=args.radius_range, latitude_range=args.latitude_range, longitude_range=args.longitude_range, pixels_per_solRad=args.pixels_per_solRad, ) return out_dir = args.out_dir if args.out_dir is not None else os.getcwd() export_series( args.nf2_path, out_dir, fmt=args.format, overwrite=args.overwrite, Mm_per_pixel=args.Mm_per_pixel, height_range=args.height_range, metrics=args.metrics, x_range=args.x_range, y_range=args.y_range, radius_range=args.radius_range, latitude_range=args.latitude_range, longitude_range=args.longitude_range, pixels_per_solRad=args.pixels_per_solRad, ) if __name__ == "__main__": main()