Source code for pyttb.import_data

"""Utilities for importing tensor data."""

# Copyright 2025 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from __future__ import annotations

import os
from typing import TextIO, Tuple, Union

import numpy as np

import pyttb as ttb


def import_data(
    filename: str, index_base: int = 1
) -> Union[ttb.sptensor, ttb.ktensor, ttb.tensor, np.ndarray]:
    """Import tensor data.

    The first line of the file names the kind of data stored in it
    (``tensor``, ``sptensor``, ``matrix`` or ``ktensor``); the remainder of
    the file is parsed according to that kind.

    Parameters
    ----------
    filename:
        File to import.
    index_base:
        Index basing allows interoperability (Primarily between python and
        MATLAB). It is subtracted from every subscript read for an
        ``sptensor``; the other data kinds do not use it.

    Returns
    -------
    A :class:`pyttb.tensor`, :class:`pyttb.sptensor` or
    :class:`pyttb.ktensor` for the corresponding data kinds, or a
    :class:`numpy.ndarray` for ``matrix``.

    Raises
    ------
    AssertionError
        If ``filename`` is not an existing file, or the data kind named on
        the first line is not one of the supported kinds.
    """
    # Raise explicitly rather than ``assert False``: assert statements are
    # removed under ``python -O``, which would silently disable validation.
    if not os.path.isfile(filename):
        raise AssertionError(f"File path {filename} does not exist.")

    with open(filename, "r") as fp:
        # Tensor type should be on the first line.
        data_type = import_type(fp)

        if data_type not in ("tensor", "sptensor", "matrix", "ktensor"):
            raise AssertionError(f"Invalid data type found: {data_type}")

        if data_type == "tensor":
            shape = import_shape(fp)
            data = import_array(fp, np.prod(shape))
            return ttb.tensor(data, shape, copy=False)

        if data_type == "sptensor":
            shape = import_shape(fp)
            nz = import_nnz(fp)
            subs, vals = import_sparse_array(fp, len(shape), nz, index_base)
            return ttb.sptensor(subs, vals, shape)

        if data_type == "matrix":
            shape = import_shape(fp)
            mat = import_array(fp, np.prod(shape))
            return np.reshape(mat, np.array(shape))

        if data_type == "ktensor":
            shape = import_shape(fp)
            r = import_rank(fp)
            weights = import_array(fp, r)
            factor_matrices = []
            # One factor matrix per tensor dimension.
            for _ in shape:
                fp.readline()  # Skip factor type
                fac_shape = import_shape(fp)
                fac = import_array(fp, np.prod(fac_shape))
                factor_matrices.append(np.reshape(fac, np.array(fac_shape)))
            return ttb.ktensor(factor_matrices, weights, copy=False)

    raise ValueError("Failed to load tensor data")  # pragma: no cover
def _first_token(fp: TextIO) -> str:
    """Read one line and return its first whitespace-delimited token.

    Returns an empty string when the line is blank or the file is exhausted.
    """
    tokens = fp.readline().split()
    return tokens[0] if tokens else ""


def import_type(fp: TextIO) -> str:
    """Extract IO data type.

    Reads one line from ``fp`` and returns its first token, or ``""`` if the
    line is blank.
    """
    return _first_token(fp)


def import_shape(fp: TextIO) -> Tuple[int, ...]:
    """Extract the shape of something from a file.

    Consumes two lines from ``fp``: the number of dimensions, then the size
    of each dimension separated by whitespace.

    Raises
    ------
    AssertionError
        If the number of sizes read differs from the declared number of
        dimensions.
    ValueError
        If a token cannot be parsed as an integer.
    """
    n = int(_first_token(fp))
    # ``split()`` with no argument tolerates tabs and repeated spaces.
    shape = tuple(int(d) for d in fp.readline().split())
    if len(shape) != n:
        # Explicit raise: ``assert`` statements vanish under ``python -O``.
        raise AssertionError("Imported dimensions are not of expected size")
    return shape


def import_nnz(fp: TextIO) -> int:
    """Extract the number of non-zeros of something from a file.

    Reads one line from ``fp`` and parses its first token as an integer.
    """
    return int(_first_token(fp))


def import_rank(fp: TextIO) -> int:
    """Extract the rank of something from a file.

    Reads one line from ``fp`` and parses its first token as an integer.
    """
    return int(_first_token(fp))


def import_sparse_array(
    fp: TextIO, n: int, nz: int, index_base: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract sparse data subs and vals from coordinate format data.

    Reads ``nz`` lines from ``fp``. Each line holds ``n`` integer subscripts
    followed by one value, separated by whitespace.

    Parameters
    ----------
    fp:
        Open text file positioned at the first coordinate line.
    n:
        Number of subscripts per entry.
    nz:
        Number of entries (lines) to read.
    index_base:
        Value subtracted from every subscript read.

    Returns
    -------
    ``subs`` as an ``(nz, n)`` int64 array and ``vals`` as an ``(nz, 1)``
    float array.

    Raises
    ------
    ValueError
        If a line does not contain exactly ``n + 1`` tokens, or a token
        cannot be parsed as a number.
    """
    subs = np.zeros((nz, n), dtype="int64")
    vals = np.zeros((nz, 1))
    for k in range(nz):
        tokens = fp.readline().split()
        # Check the width up front: NumPy would otherwise broadcast a lone
        # subscript across the whole row instead of reporting the bad line.
        if len(tokens) != n + 1:
            raise ValueError(
                f"Sparse entry {k} has {len(tokens)} values but {n + 1} "
                "were expected"
            )
        subs[k, :] = [int(i) - index_base for i in tokens[:-1]]
        vals[k, 0] = float(tokens[-1])
    return subs, vals


def import_array(fp: TextIO, n: Union[int, np.integer]) -> np.ndarray:
    """Extract numpy array from file.

    Delegates to :func:`numpy.fromfile` with ``sep=" "``, requesting ``n``
    items starting at the current position of ``fp``.
    """
    return np.fromfile(fp, count=n, sep=" ")